Source code for torchrl.envs.libs.meltingpot

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib

from typing import Dict, List, Mapping, Sequence

import torch

from tensordict import TensorDict, TensorDictBase

from torchrl.data import Categorical, Composite, TensorSpec
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

_has_meltingpot = importlib.util.find_spec("meltingpot") is not None

PLAYER_STR_FORMAT = "player_{index}"
_WORLD_PREFIX = "WORLD."


def _get_envs():
    if not _has_meltingpot:
        raise ImportError("meltingpot is not installed in your virtual environment.")
    from meltingpot.configs import substrates as substrate_configs

    return list(substrate_configs.SUBSTRATES)
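
# A hedged usage sketch: the private helper above is surfaced publicly through the
# ``available_envs`` classproperty of the wrapper classes below, so the list of
# substrate names can be queried without touching private functions. Assumes
# ``meltingpot`` is installed:
#
#     >>> from torchrl.envs.libs.meltingpot import MeltingpotEnv
#     >>> "commons_harvest__open" in MeltingpotEnv.available_envs
#     True
#     >>> len(MeltingpotEnv.available_envs)  # 50+ substrates, per the Melting Pot paper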


def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict:  # noqa
    return {
        key: value
        for key, value in obs_dict.items()
        if ((_WORLD_PREFIX not in key) if not world else (_WORLD_PREFIX in key))
    }


def _remove_world_observations_from_obs_spec(
    observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]],  # noqa
) -> Sequence[Mapping[str, "dm_env.specs.Array"]]:  # noqa
    return [
        _filter_global_state_from_dict(agent_obs, world=False)
        for agent_obs in observation_spec
    ]


def _global_state_spec_from_obs_spec(
    observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]]  # noqa
) -> Mapping[str, "dm_env.specs.Array"]:  # noqa
    # We only look at agent 0 since world entries are the same for all agents
    world_entries = _filter_global_state_from_dict(observation_spec[0], world=True)
    if len(world_entries) != 1 or _WORLD_PREFIX + "RGB" not in world_entries:
        raise ValueError(
            f"Expected only one world entry named {_WORLD_PREFIX}RGB in observation_spec, but got {world_entries}"
        )
    return _remove_world_prefix(world_entries)


def _remove_world_prefix(world_entries: Dict) -> Dict:
    return {key[len(_WORLD_PREFIX) :]: value for key, value in world_entries.items()}
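
# A minimal sketch (with hypothetical observation keys) of how the two helpers above
# cooperate: per-agent entries are kept as-is, while global entries are selected by
# the ``WORLD.`` marker and then renamed without it.
#
#     >>> obs = {"RGB": 1, "READY_TO_SHOOT": 2, "WORLD.RGB": 3}
#     >>> _filter_global_state_from_dict(obs, world=False)
#     {'RGB': 1, 'READY_TO_SHOOT': 2}
#     >>> _remove_world_prefix(_filter_global_state_from_dict(obs, world=True))
#     {'RGB': 3}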


class MeltingpotWrapper(_EnvWrapper):
    """Meltingpot environment wrapper.

    GitHub: https://github.com/google-deepmind/meltingpot

    Paper: https://arxiv.org/abs/2211.13746

    Melting Pot assesses generalization to novel social situations involving both familiar
    and unfamiliar individuals, and has been designed to test a broad range of social
    interactions such as: cooperation, competition, deception, reciprocation, trust,
    stubbornness and so on. Melting Pot offers researchers a set of over 50 multi-agent
    reinforcement learning substrates (multi-agent games) on which to train agents, and
    over 256 unique test scenarios on which to evaluate these trained agents.

    Args:
        env (``meltingpot.utils.substrates.substrate.Substrate``): the meltingpot substrate to wrap.

    Keyword Args:
        max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon).
            Each Meltingpot substrate can be terminating or not. If ``max_steps`` is specified,
            the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this
            horizon is reached. Unlike gym's ``TimeLimit`` transform or torchrl's
            :class:`~torchrl.envs.transforms.StepCounter`, this argument will not set the
            ``"truncated"`` entry in the tensordict.
        categorical_actions (bool, optional): if the environment actions are discrete, whether
            to transform them to categorical or one-hot. Defaults to ``True``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in
            tensordicts for input/output. By default, they will all be put in one group named
            ``"agents"``. Otherwise, a group map can be specified or selected from some premade
            options. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.

    Attributes:
        group_map (Dict[str, List[str]]): how to group agents in tensordicts for input/output.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        agent_names (list of str): names of the agents in the environment.
        agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their
            index in the environment.
        available_envs (List[str]): the list of the scenarios available to build.

    .. warning::
        Meltingpot returns a single ``done`` flag which does not distinguish between
        when the env reached ``max_steps`` and termination.
        If you deem the ``truncation`` signal necessary, set ``max_steps`` to
        ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.

    Examples:
        >>> from meltingpot import substrate
        >>> from torchrl.envs.libs.meltingpot import MeltingpotWrapper
        >>> substrate_config = substrate.get_config("commons_harvest__open")
        >>> mp_env = substrate.build_from_config(
        ...     substrate_config, roles=substrate_config.default_player_roles
        ... )
        >>> env_torchrl = MeltingpotWrapper(env=mp_env)
        >>> print(env_torchrl.rollout(max_steps=5))
        TensorDict(
            fields={
                RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False),
                        observation: TensorDict(
                            fields={
                                COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([5, 7]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        agents: TensorDict(
                            fields={
                                observation: TensorDict(
                                    fields={
                                        COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                                    batch_size=torch.Size([5, 7]),
                                    device=cpu,
                                    is_shared=False),
                                reward: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
    """

    git_url = "https://github.com/google-deepmind/meltingpot"
    libname = "meltingpot"

    @property
    def lib(self):
        import meltingpot

        return meltingpot

    @_classproperty
    def available_envs(cls):
        if not _has_meltingpot:
            return []
        return _get_envs()

    def __init__(
        self,
        env: "meltingpot.utils.substrates.substrate.Substrate" = None,  # noqa
        categorical_actions: bool = True,
        group_map: MarlGroupMapType
        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        max_steps: int | None = None,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env
        self.group_map = group_map
        self.categorical_actions = categorical_actions
        self.max_steps = max_steps
        self.num_cycles = 0
        super().__init__(**kwargs)

    def _build_env(
        self,
        env: "meltingpot.utils.substrates.substrate.Substrate",  # noqa
    ):
        return env

    def _make_group_map(self):
        if isinstance(self.group_map, MarlGroupMapType):
            self.group_map = self.group_map.get_group_map(self.agent_names)
        check_marl_grouping(self.group_map, self.agent_names)

    def _make_specs(
        self, env: "meltingpot.utils.substrates.substrate.Substrate"  # noqa
    ) -> None:
        mp_obs_spec = self._env.observation_spec()  # List of dict of arrays
        mp_obs_spec_no_world = _remove_world_observations_from_obs_spec(
            mp_obs_spec
        )  # List of dict of arrays
        mp_global_state_spec = _global_state_spec_from_obs_spec(
            mp_obs_spec
        )  # Dict of arrays
        mp_act_spec = self._env.action_spec()  # List of discrete arrays
        mp_rew_spec = self._env.reward_spec()  # List of arrays

        torchrl_agent_obs_specs = [
            _dmcontrol_to_torchrl_spec_transform(agent_obs_spec)
            for agent_obs_spec in mp_obs_spec_no_world
        ]
        torchrl_agent_act_specs = [
            _dmcontrol_to_torchrl_spec_transform(
                agent_act_spec, categorical_discrete_encoding=self.categorical_actions
            )
            for agent_act_spec in mp_act_spec
        ]
        torchrl_state_spec = _dmcontrol_to_torchrl_spec_transform(mp_global_state_spec)
        torchrl_rew_spec = [
            _dmcontrol_to_torchrl_spec_transform(agent_rew_spec)
            for agent_rew_spec in mp_rew_spec
        ]

        # Create and check group map
        _num_players = len(torchrl_rew_spec)
        self.agent_names = [
            PLAYER_STR_FORMAT.format(index=index) for index in range(_num_players)
        ]
        self.agent_names_to_indices_map = {
            agent_name: i for i, agent_name in enumerate(self.agent_names)
        }
        self._make_group_map()

        action_spec = Composite()
        observation_spec = Composite()
        reward_spec = Composite()
        for group in self.group_map.keys():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
            ) = self._make_group_specs(
                group,
                torchrl_agent_obs_specs,
                torchrl_agent_act_specs,
                torchrl_rew_spec,
            )
            action_spec[group] = group_action_spec
            observation_spec[group] = group_observation_spec
            reward_spec[group] = group_reward_spec

        observation_spec.update(torchrl_state_spec)
        self.done_spec = Composite(
            {
                "done": Categorical(n=2, shape=torch.Size((1,)), dtype=torch.bool),
            },
        )
        self.action_spec = action_spec
        self.observation_spec = observation_spec
        self.reward_spec = reward_spec

    def _make_group_specs(
        self,
        group: str,
        torchrl_agent_obs_specs: List[TensorSpec],
        torchrl_agent_act_specs: List[TensorSpec],
        torchrl_rew_spec: List[TensorSpec],
    ):
        # Agent specs
        action_specs = []
        observation_specs = []
        reward_specs = []
        for agent_name in self.group_map[group]:
            agent_index = self.agent_names_to_indices_map[agent_name]
            action_specs.append(
                Composite(
                    {
                        "action": torchrl_agent_act_specs[
                            agent_index
                        ]  # shape = (n_actions_per_agent,)
                    },
                )
            )
            observation_specs.append(
                Composite(
                    {
                        "observation": torchrl_agent_obs_specs[
                            agent_index
                        ]  # shape = (n_obs_per_agent,)
                    },
                )
            )
            reward_specs.append(
                Composite({"reward": torchrl_rew_spec[agent_index]})  # shape = (1,)
            )

        # Create multi-agent specs
        group_action_spec = torch.stack(
            action_specs, dim=0
        )  # shape = (n_agents_in_group, n_actions_per_agent)
        group_observation_spec = torch.stack(
            observation_specs, dim=0
        )  # shape = (n_agents_in_group, n_obs_per_agent)
        group_reward_spec = torch.stack(
            reward_specs, dim=0
        )  # shape = (n_agents_in_group, 1)
        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
        )

    def _check_kwargs(self, kwargs: Dict):
        meltingpot = self.lib

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, meltingpot.utils.substrates.substrate.Substrate):
            raise TypeError(
                "env is not of type 'meltingpot.utils.substrates.substrate.Substrate'."
            )

    def _init_env(self):
        # Caching
        self.cached_full_done_spec_zero = self.full_done_spec.zero()

    def _set_seed(self, seed: int | None):
        raise NotImplementedError(
            "It is currently unclear how to set a seed in Meltingpot. "
            "See https://github.com/google-deepmind/meltingpot/issues/129 to track the issue."
        )

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        self.num_cycles = 0
        timestep = self._env.reset()
        obs = timestep.observation

        td = self.cached_full_done_spec_zero.clone()

        for group, agent_names in self.group_map.items():
            agent_tds = []
            for index_in_group, agent_name in enumerate(agent_names):
                global_index = self.agent_names_to_indices_map[agent_name]
                agent_obs = self.observation_spec[group, "observation"][
                    index_in_group
                ].encode(_filter_global_state_from_dict(obs[global_index], world=False))
                agent_td = TensorDict(
                    source={
                        "observation": agent_obs,
                    },
                    batch_size=self.batch_size,
                    device=self.device,
                )
                agent_tds.append(agent_td)
            agent_tds = torch.stack(agent_tds, dim=0)
            td.set(group, agent_tds)

        # Global state
        td.update(
            _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True))
        )

        tensordict_out = TensorDict(
            source=td,
            batch_size=self.batch_size,
            device=self.device,
        )

        return tensordict_out

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        action_dict = {}
        for group, agents in self.group_map.items():
            group_action = tensordict.get((group, "action"))
            group_action_np = self.full_action_spec[group, "action"].to_numpy(
                group_action
            )
            for index, agent in enumerate(agents):
                action_dict[agent] = group_action_np[index]
        actions = [action_dict[agent] for agent in self.agent_names]
        timestep = self._env.step(actions)
        self.num_cycles += 1

        rewards = timestep.reward
        done = timestep.last() or (
            (self.num_cycles >= self.max_steps) if self.max_steps is not None else False
        )
        obs = timestep.observation

        td = TensorDict(
            {
                "done": self.full_done_spec["done"].encode(done),
                "terminated": self.full_done_spec["terminated"].encode(done),
            },
            batch_size=self.batch_size,
        )
        # Global state
        td.update(
            _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True))
        )

        for group, agent_names in self.group_map.items():
            agent_tds = []
            for index_in_group, agent_name in enumerate(agent_names):
                global_index = self.agent_names_to_indices_map[agent_name]
                agent_obs = self.observation_spec[group, "observation"][
                    index_in_group
                ].encode(_filter_global_state_from_dict(obs[global_index], world=False))
                agent_reward = self.full_reward_spec[group, "reward"][
                    index_in_group
                ].encode(rewards[global_index])

                agent_td = TensorDict(
                    source={
                        "observation": agent_obs,
                        "reward": agent_reward,
                    },
                    batch_size=self.batch_size,
                    device=self.device,
                )
                agent_tds.append(agent_td)
            agent_tds = torch.stack(agent_tds, dim=0)
            td.set(group, agent_tds)

        return td

    def get_rgb_image(self) -> torch.Tensor:
        """Returns an RGB image of the environment.

        Returns:
            a ``torch.Tensor`` containing image in format WHC.
        """
        return torch.from_numpy(self._env.observation()[0][_WORLD_PREFIX + "RGB"])
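
# A hedged sketch of the ``group_map`` keyword documented above: instead of the default
# single ``"agents"`` group, each player can be given its own tensordict key. This
# assumes ``meltingpot`` is installed and reuses the substrate from the docstring example.
#
#     >>> from meltingpot import substrate
#     >>> from torchrl.envs.libs.meltingpot import MeltingpotWrapper
#     >>> from torchrl.envs.utils import MarlGroupMapType
#     >>> config = substrate.get_config("commons_harvest__open")
#     >>> mp_env = substrate.build_from_config(config, roles=config.default_player_roles)
#     >>> env = MeltingpotWrapper(env=mp_env, group_map=MarlGroupMapType.ONE_GROUP_PER_AGENT)
#     >>> sorted(env.group_map)[:2]  # one group per player
#     ['player_0', 'player_1']
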
class MeltingpotEnv(MeltingpotWrapper):
    """Meltingpot environment wrapper.

    GitHub: https://github.com/google-deepmind/meltingpot

    Paper: https://arxiv.org/abs/2211.13746

    Melting Pot assesses generalization to novel social situations involving both familiar
    and unfamiliar individuals, and has been designed to test a broad range of social
    interactions such as: cooperation, competition, deception, reciprocation, trust,
    stubbornness and so on. Melting Pot offers researchers a set of over 50 multi-agent
    reinforcement learning substrates (multi-agent games) on which to train agents, and
    over 256 unique test scenarios on which to evaluate these trained agents.

    Args:
        substrate (str or ml_collections.config_dict.ConfigDict): the meltingpot substrate
            to build. Can be a string from :attr:`~.available_envs` or a ConfigDict for
            the substrate.

    Keyword Args:
        max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon).
            Each Meltingpot substrate can be terminating or not. If ``max_steps`` is specified,
            the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this
            horizon is reached. Unlike gym's ``TimeLimit`` transform or torchrl's
            :class:`~torchrl.envs.transforms.StepCounter`, this argument will not set the
            ``"truncated"`` entry in the tensordict.
        categorical_actions (bool, optional): if the environment actions are discrete, whether
            to transform them to categorical or one-hot. Defaults to ``True``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in
            tensordicts for input/output. By default, they will all be put in one group named
            ``"agents"``. Otherwise, a group map can be specified or selected from some premade
            options. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.

    Attributes:
        group_map (Dict[str, List[str]]): how to group agents in tensordicts for input/output.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        agent_names (list of str): names of the agents in the environment.
        agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their
            index in the environment.
        available_envs (List[str]): the list of the scenarios available to build.

    .. warning::
        Meltingpot returns a single ``done`` flag which does not distinguish between
        when the env reached ``max_steps`` and termination.
        If you deem the ``truncation`` signal necessary, set ``max_steps`` to
        ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.

    Examples:
        >>> from torchrl.envs.libs.meltingpot import MeltingpotEnv
        >>> env_torchrl = MeltingpotEnv("commons_harvest__open")
        >>> print(env_torchrl.rollout(max_steps=5))
        TensorDict(
            fields={
                RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False),
                        observation: TensorDict(
                            fields={
                                COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([5, 7]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        agents: TensorDict(
                            fields={
                                observation: TensorDict(
                                    fields={
                                        COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                                    batch_size=torch.Size([5, 7]),
                                    device=cpu,
                                    is_shared=False),
                                reward: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
    """

    def __init__(
        self,
        substrate: str | "ml_collections.config_dict.ConfigDict",  # noqa
        *,
        max_steps: int | None = None,
        categorical_actions: bool = True,
        group_map: MarlGroupMapType
        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        **kwargs,
    ):
        if not _has_meltingpot:
            raise ImportError(
                f"meltingpot python package was not found. Please install this dependency. "
                f"More info: {self.git_url}."
            )
        super().__init__(
            substrate=substrate,
            max_steps=max_steps,
            categorical_actions=categorical_actions,
            group_map=group_map,
            **kwargs,
        )

    def _check_kwargs(self, kwargs: Dict):
        if "substrate" not in kwargs:
            raise TypeError("Could not find environment key 'substrate' in kwargs.")

    def _build_env(
        self,
        substrate: str | "ml_collections.config_dict.ConfigDict",  # noqa
    ) -> "meltingpot.utils.substrates.substrate.Substrate":  # noqa
        from meltingpot import substrate as mp_substrate

        if isinstance(substrate, str):
            substrate_config = mp_substrate.get_config(substrate)
        else:
            substrate_config = substrate

        return super()._build_env(
            env=mp_substrate.build_from_config(
                substrate_config, roles=substrate_config.default_player_roles
            )
        )
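
# Following the warning in the docstrings above, a minimal sketch of recovering a true
# ``truncated`` signal: leave ``max_steps`` at ``None`` and bound the horizon with
# torchrl's ``StepCounter`` transform instead (the horizon of 100 is illustrative).
#
#     >>> from torchrl.envs import TransformedEnv
#     >>> from torchrl.envs.libs.meltingpot import MeltingpotEnv
#     >>> from torchrl.envs.transforms import StepCounter
#     >>> env = TransformedEnv(
#     ...     MeltingpotEnv("commons_harvest__open"),  # max_steps left as None
#     ...     StepCounter(max_steps=100),
#     ... )
#     >>> rollout = env.rollout(1_000)  # stops after 100 steps with "truncated" set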
