Shortcuts

Source code for torchrl.envs.libs.vmas

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util

from typing import Dict, List, Optional, Union

import torch
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data import (
    BoundedTensorSpec,
    CompositeSpec,
    DEVICE_TYPING,
    DiscreteTensorSpec,
    LazyStackedCompositeSpec,
    MultiDiscreteTensorSpec,
    MultiOneHotDiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict
from torchrl.envs.common import _EnvWrapper, EnvBase
from torchrl.envs.libs.gym import gym_backend, set_gym_backend
from torchrl.envs.utils import (
    _classproperty,
    _selective_unsqueeze,
    check_marl_grouping,
    MarlGroupMapType,
)

# True when the optional ``vmas`` package can be located by the import system;
# checked before any code path that needs to import it.
_has_vmas = importlib.util.find_spec("vmas") is not None


# Public names of this module.
__all__ = ["VmasWrapper", "VmasEnv"]


def _get_envs():
    """Return the VMAS scenarios that this wrapper can currently handle.

    The result is the concatenation of ``vmas.scenarios``,
    ``vmas.mpe_scenarios`` and ``vmas.debug_scenarios`` (in that order) with
    the scenarios listed as unsupported below removed.

    Raises:
        ImportError: if the ``vmas`` package is not installed.
    """
    if not _has_vmas:
        raise ImportError("VMAS is not installed in your virtual environment.")
    import vmas

    # TODO heterogeneous spaces
    # For now torchrl does not support heterogeneous spaces (Tuple(Box)), so
    # many OpenAI MPE scenarios do not work and are filtered out here.
    unsupported = [
        "simple_adversary",
        "simple_crypto",
        "simple_push",
        "simple_speaker_listener",
        "simple_tag",
        "simple_world_comm",
    ]
    candidates = vmas.scenarios + vmas.mpe_scenarios + vmas.debug_scenarios

    supported = []
    for name in candidates:
        if name in unsupported:
            continue
        supported.append(name)
    return supported


@set_gym_backend("gym")
def _vmas_to_torchrl_spec_transform(
    spec,
    device,
    categorical_action_encoding,
) -> TensorSpec:
    """Convert a gym space exposed by VMAS into the matching TorchRL spec.

    Args:
        spec: a gym ``Discrete``, ``MultiDiscrete`` or ``Box`` space.
        device: device on which the resulting spec (and its bounds) is created.
        categorical_action_encoding: for discrete spaces, ``True`` selects the
            categorical specs (using the space's own dtype) and ``False``
            selects the one-hot specs (using ``torch.long``). Ignored for
            ``Box`` spaces.

    Returns:
        The TorchRL ``TensorSpec`` corresponding to ``spec``.

    Raises:
        NotImplementedError: if ``spec`` is none of the supported space types.
    """
    spaces = gym_backend("spaces")

    if isinstance(spec, spaces.discrete.Discrete):
        if categorical_action_encoding:
            return DiscreteTensorSpec(
                spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype]
            )
        return OneHotDiscreteTensorSpec(spec.n, device=device, dtype=torch.long)

    if isinstance(spec, spaces.multi_discrete.MultiDiscrete):
        if categorical_action_encoding:
            return MultiDiscreteTensorSpec(
                spec.nvec, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype]
            )
        return MultiOneHotDiscreteTensorSpec(
            spec.nvec, device=device, dtype=torch.long
        )

    if isinstance(spec, spaces.Box):
        # A zero-dimensional box is represented with a single-element shape.
        box_shape = spec.shape
        if len(box_shape) == 0:
            box_shape = torch.Size([1])
        box_dtype = numpy_to_torch_dtype_dict[spec.dtype]
        lower = torch.tensor(spec.low, device=device, dtype=box_dtype)
        upper = torch.tensor(spec.high, device=device, dtype=box_dtype)

        # Only boxes that are infinite on every entry of both bounds are
        # treated as unbounded.
        if lower.isinf().all() and upper.isinf().all():
            return UnboundedContinuousTensorSpec(
                box_shape, device=device, dtype=box_dtype
            )
        return BoundedTensorSpec(
            lower,
            upper,
            box_shape,
            dtype=box_dtype,
            device=device,
        )

    raise NotImplementedError(
        f"spec of type {type(spec).__name__} is currently unaccounted for vmas"
    )


class VmasWrapper(_EnvWrapper):
    """Vmas environment wrapper.

    Args:
        env (``vmas.simulator.environment.environment.Environment``): the vmas
            environment to wrap.
        categorical_actions (bool, optional): if the environment actions are
            discrete, whether to transform them to categorical or one-hot.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
            group agents in tensordicts for input/output. By default, if the
            agent names follow the ``"<name>_<int>"`` convention, they will be
            grouped by ``"<name>"``. If they do not follow this convention,
            they will be all put in one group named ``"agents"``.
            Otherwise, a group map can be specified or selected from some
            premade options.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.

    Attributes:
        group_map (Dict[str, List[str]]): how to group agents in tensordicts for
            input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for
            more info.
        agent_names (list of str): names of the agent in the environment
        agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent
            names to their index in the environment
        unbatched_action_spec (TensorSpec): version of the spec without the
            vectorized dimension
        unbatched_observation_spec (TensorSpec): version of the spec without the
            vectorized dimension
        unbatched_reward_spec (TensorSpec): version of the spec without the
            vectorized dimension
        het_specs (bool): whether the environment has any lazy spec
        het_specs_map (Dict[str, bool]): dictionary mapping each group to a flag
            representing whether the group has lazy specs

    Examples:
        >>> env = VmasWrapper(
        ...     vmas.make_env(
        ...         scenario="flocking",
        ...         num_envs=32,
        ...         continuous_actions=True,
        ...         max_steps=200,
        ...         device="cpu",
        ...         seed=None,
        ...         # Scenario kwargs
        ...         n_agents=5,
        ...     )
        ... )
        >>> print(env.rollout(10))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([32, 10, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                        info: TensorDict(
                            fields={
                                agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 10, 5]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                info: TensorDict(
                                    fields={
                                        agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                        agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([32, 10, 5]),
                                    device=cpu,
                                    is_shared=False),
                                observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([32, 10]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([32, 10]),
            device=cpu,
            is_shared=False)
    """

    git_url = "https://github.com/proroklab/VectorizedMultiAgentSimulator"
    libname = "vmas"

    @property
    def lib(self):
        """Import and return the ``vmas`` module."""
        import vmas

        return vmas

    @_classproperty
    def available_envs(cls):
        """Yield the scenarios from ``_get_envs()``; nothing if vmas is missing."""
        if not _has_vmas:
            return
        yield from _get_envs()

    def __init__(
        self,
        env: "vmas.simulator.environment.environment.Environment" = None,  # noqa
        categorical_actions: bool = True,
        group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
        **kwargs,
    ):
        """Store the grouping/encoding options and defer to the parent constructor.

        Raises:
            TypeError: if ``env`` is given together with a ``device`` keyword
                whose string form differs from ``str(env.device)``.
        """
        if env is not None:
            kwargs["env"] = env
            # Compare string forms on both sides so that a ``torch.device``
            # object naming the same device as the env is accepted as well as
            # a plain string.
            if "device" in kwargs and str(kwargs["device"]) != str(env.device):
                raise TypeError("Env device is different from vmas device")
            kwargs["device"] = str(env.device)
        self.group_map = group_map
        self.categorical_actions = categorical_actions
        super().__init__(**kwargs, allow_done_after_reset=True)

    def _build_env(
        self,
        env: "vmas.simulator.environment.environment.Environment",  # noqa
        from_pixels: bool = False,
        pixels_only: bool = False,
    ):
        """Validate the rendering options and batch size, then return ``env``.

        If no batch size was set, it is initialised to ``(env.num_envs,)``.

        Raises:
            NotImplementedError: if ``from_pixels`` is true.
            TypeError: if the batch size has more than one dimension, or its
                single dimension differs from ``env.num_envs``.
        """
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only

        # TODO pixels
        if self.from_pixels:
            raise NotImplementedError("vmas rendering not yet implemented")

        # Adjust batch size
        if len(self.batch_size) == 0:
            # Batch size not set
            self.batch_size = torch.Size((env.num_envs,))
        elif len(self.batch_size) == 1:
            # Batch size is set
            if not self.batch_size[0] == env.num_envs:
                raise TypeError(
                    "Batch size used in constructor does not match vmas batch size."
                )
        else:
            raise TypeError(
                "Batch size used in constructor is not compatible with vmas."
            )

        return env

    def _get_default_group_map(self, agent_names: List[str]):
        """Build the default mapping from group name to agent names.

        Agents with names "<name>_<int>" will be grouped in group name "<name>".
        If any of the agents does not follow the naming convention, we fall
        back on having all agents in one group named "agents".
        """
        group_map = {}
        follows_convention = True
        for agent_name in agent_names:
            # See if the agent follows the convention "<name>_<int>"
            agent_name_split = agent_name.split("_")
            if len(agent_name_split) == 1:
                follows_convention = False
            follows_convention = follows_convention and agent_name_split[-1].isdigit()

            if not follows_convention:
                break

            # Group it with other agents that follow the same convention
            group_name = "_".join(agent_name_split[:-1])
            group_map.setdefault(group_name, []).append(agent_name)

        if not follows_convention:
            group_map = MarlGroupMapType.ALL_IN_ONE_GROUP.get_group_map(agent_names)

        # For BC-compatibility rename the "agent" group to "agents"
        if "agent" in group_map:
            group_map["agents"] = group_map.pop("agent")

        return group_map

    def _make_specs(
        self, env: "vmas.simulator.environment.environment.Environment"  # noqa
    ) -> None:
        """Resolve the group map and build the unbatched and batched specs."""
        # Create and check group map
        self.agent_names = [agent.name for agent in self.agents]
        self.agent_names_to_indices_map = {
            agent.name: i for i, agent in enumerate(self.agents)
        }
        if self.group_map is None:
            self.group_map = self._get_default_group_map(self.agent_names)
        elif isinstance(self.group_map, MarlGroupMapType):
            self.group_map = self.group_map.get_group_map(self.agent_names)
        check_marl_grouping(self.group_map, self.agent_names)

        self.unbatched_action_spec = CompositeSpec(device=self.device)
        self.unbatched_observation_spec = CompositeSpec(device=self.device)
        self.unbatched_reward_spec = CompositeSpec(device=self.device)

        self.het_specs = False
        self.het_specs_map = {}
        for group in self.group_map.keys():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
                group_info_spec,
            ) = self._make_unbatched_group_specs(group)
            self.unbatched_action_spec[group] = group_action_spec
            self.unbatched_observation_spec[group] = group_observation_spec
            self.unbatched_reward_spec[group] = group_reward_spec
            if group_info_spec is not None:
                self.unbatched_observation_spec[(group, "info")] = group_info_spec
            # A group is flagged heterogeneous when stacking its per-agent
            # observation or action specs produced a lazy stacked spec.
            group_het_specs = isinstance(
                group_observation_spec, LazyStackedCompositeSpec
            ) or isinstance(group_action_spec, LazyStackedCompositeSpec)
            self.het_specs_map[group] = group_het_specs
            self.het_specs = self.het_specs or group_het_specs

        self.unbatched_done_spec = CompositeSpec(
            {
                "done": DiscreteTensorSpec(
                    n=2,
                    shape=torch.Size((1,)),
                    dtype=torch.bool,
                    device=self.device,
                ),
            },
        )

        # Batched specs: the unbatched ones with the env batch size prepended.
        self.action_spec = self.unbatched_action_spec.expand(
            *self.batch_size, *self.unbatched_action_spec.shape
        )
        self.observation_spec = self.unbatched_observation_spec.expand(
            *self.batch_size, *self.unbatched_observation_spec.shape
        )
        self.reward_spec = self.unbatched_reward_spec.expand(
            *self.batch_size, *self.unbatched_reward_spec.shape
        )
        self.done_spec = self.unbatched_done_spec.expand(
            *self.batch_size, *self.unbatched_done_spec.shape
        )

    def _make_unbatched_group_specs(self, group: str):
        """Build the stacked per-agent specs for one group.

        Returns:
            A tuple ``(observation_spec, action_spec, reward_spec, info_spec)``
            where ``info_spec`` is ``None`` when no agent of the group reports
            any info entry.
        """
        # Agent specs
        action_specs = []
        observation_specs = []
        reward_specs = []
        info_specs = []
        for agent_name in self.group_map[group]:
            agent_index = self.agent_names_to_indices_map[agent_name]
            agent = self.agents[agent_index]
            action_specs.append(
                CompositeSpec(
                    {
                        "action": _vmas_to_torchrl_spec_transform(
                            self.action_space[agent_index],
                            categorical_action_encoding=self.categorical_actions,
                            device=self.device,
                        )  # shape = (n_actions_per_agent,)
                    },
                )
            )
            observation_specs.append(
                CompositeSpec(
                    {
                        "observation": _vmas_to_torchrl_spec_transform(
                            self.observation_space[agent_index],
                            device=self.device,
                            categorical_action_encoding=self.categorical_actions,
                        )  # shape = (n_obs_per_agent,)
                    },
                )
            )
            reward_specs.append(
                CompositeSpec(
                    {
                        "reward": UnboundedContinuousTensorSpec(
                            shape=torch.Size((1,)),
                            device=self.device,
                        )  # shape = (1,)
                    }
                )
            )
            agent_info = self.scenario.info(agent)
            if len(agent_info):
                info_specs.append(
                    CompositeSpec(
                        {
                            # Drop the leading (batch) dimension of the sample
                            # value to obtain the unbatched shape.
                            key: UnboundedContinuousTensorSpec(
                                shape=_selective_unsqueeze(
                                    value, batch_size=self.batch_size
                                ).shape[1:],
                                device=self.device,
                                dtype=torch.float32,
                            )
                            for key, value in agent_info.items()
                        },
                    ).to(self.device)
                )

        # Create multi-agent specs
        group_action_spec = torch.stack(
            action_specs, dim=0
        )  # shape = (n_agents, n_actions_per_agent)
        group_observation_spec = torch.stack(
            observation_specs, dim=0
        )  # shape = (n_agents, n_obs_per_agent)
        group_reward_spec = torch.stack(reward_specs, dim=0)  # shape = (n_agents, 1)
        group_info_spec = None
        if len(info_specs):
            group_info_spec = torch.stack(info_specs, dim=0)

        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
            group_info_spec,
        )

    def _check_kwargs(self, kwargs: Dict):
        """Check that ``kwargs`` carries a vmas ``Environment`` under ``"env"``.

        Raises:
            TypeError: if the key is missing or the value has the wrong type.
        """
        vmas = self.lib
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, vmas.simulator.environment.Environment):
            raise TypeError(
                "env is not of type 'vmas.simulator.environment.Environment'."
            )

    def _init_env(self) -> Optional[int]:
        """No-op: nothing to initialise for vmas."""
        pass

    def _set_seed(self, seed: Optional[int]):
        """Forward ``seed`` to the wrapped environment."""
        self._env.seed(seed)

    def _reset(
        self, tensordict: Optional[TensorDictBase] = None, **kwargs
    ) -> TensorDictBase:
        """Reset all, or only the flagged, sub-environments and read their state.

        When ``tensordict`` carries a ``"_reset"`` entry, only the environments
        it flags are reset (all of them at once if every flag is set).
        Otherwise every environment is reset.

        Returns:
            A tensordict with ``"done"``, ``"terminated"`` and one entry per
            group holding the stacked per-agent observations (and infos).
        """
        if tensordict is not None and "_reset" in tensordict.keys():
            _reset = tensordict.get("_reset")
            envs_to_reset = _reset.squeeze(-1)
            if envs_to_reset.all():
                self._env.reset(return_observations=False)
            else:
                for env_index, to_reset in enumerate(envs_to_reset):
                    if to_reset:
                        self._env.reset_at(env_index, return_observations=False)
        else:
            self._env.reset(return_observations=False)

        obs, dones, infos = self._env.get_from_scenario(
            get_observations=True,
            get_infos=True,
            get_rewards=False,
            get_dones=True,
        )
        dones = self.read_done(dones)

        source = {"done": dones, "terminated": dones.clone()}
        for group, agent_names in self.group_map.items():
            agent_tds = []
            for agent_name in agent_names:
                i = self.agent_names_to_indices_map[agent_name]

                agent_obs = self.read_obs(obs[i])
                agent_info = self.read_info(infos[i])

                agent_td = TensorDict(
                    source={
                        "observation": agent_obs,
                    },
                    batch_size=self.batch_size,
                    device=self.device,
                )
                if agent_info is not None:
                    agent_td.set("info", agent_info)
                agent_tds.append(agent_td)

            # Stack agents along dim 1, right after the env batch dimension.
            agent_tds = torch.stack(agent_tds, dim=1)
            if not self.het_specs_map[group]:
                agent_tds = agent_tds.to_tensordict()
            source.update({group: agent_tds})

        tensordict_out = TensorDict(
            source=source,
            batch_size=self.batch_size,
            device=self.device,
        )

        return tensordict_out

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Apply the grouped actions to the wrapped env and read the results.

        Returns:
            A tensordict with ``"done"``, ``"terminated"`` and one entry per
            group holding the stacked per-agent observations, rewards (and
            infos).
        """
        # ``agent_indices`` maps an agent's index in the environment to the
        # position of its action in the flat ``action_list``, so the actions
        # can be re-ordered into environment order below.
        agent_indices = {}
        action_list = []
        n_agents = 0
        for group, agent_names in self.group_map.items():
            group_action = tensordict.get((group, "action"))
            group_action_list = list(self.read_action(group_action, group=group))
            agent_indices.update(
                {
                    self.agent_names_to_indices_map[agent_name]: i + n_agents
                    for i, agent_name in enumerate(agent_names)
                }
            )
            n_agents += len(agent_names)
            action_list += group_action_list
        action = [action_list[agent_indices[i]] for i in range(self.n_agents)]

        obs, rews, dones, infos = self._env.step(action)

        dones = self.read_done(dones)

        source = {"done": dones, "terminated": dones.clone()}
        for group, agent_names in self.group_map.items():
            agent_tds = []
            for agent_name in agent_names:
                i = self.agent_names_to_indices_map[agent_name]

                agent_obs = self.read_obs(obs[i])
                agent_rew = self.read_reward(rews[i])
                agent_info = self.read_info(infos[i])

                agent_td = TensorDict(
                    source={
                        "observation": agent_obs,
                        "reward": agent_rew,
                    },
                    batch_size=self.batch_size,
                    device=self.device,
                )
                if agent_info is not None:
                    agent_td.set("info", agent_info)
                agent_tds.append(agent_td)

            # Stack agents along dim 1, right after the env batch dimension.
            agent_tds = torch.stack(agent_tds, dim=1)
            if not self.het_specs_map[group]:
                agent_tds = agent_tds.to_tensordict()
            source.update({group: agent_tds})

        tensordict_out = TensorDict(
            source=source,
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def read_obs(
        self, observations: Union[Dict, torch.Tensor]
    ) -> Union[TensorDictBase, torch.Tensor]:
        """Convert an agent observation, recursing into dict observations.

        A tensor is passed through ``_selective_unsqueeze``; a dict is turned
        into a ``TensorDict`` whose values are converted the same way.
        """
        if isinstance(observations, torch.Tensor):
            return _selective_unsqueeze(observations, batch_size=self.batch_size)
        return TensorDict(
            source={key: self.read_obs(value) for key, value in observations.items()},
            batch_size=self.batch_size,
        )

    def read_info(self, infos: Dict[str, torch.Tensor]) -> Optional[TensorDictBase]:
        """Convert an agent info dict to a float32 ``TensorDict``.

        Returns:
            ``None`` when ``infos`` is empty, otherwise a ``TensorDict`` with
            one float32 entry per key.
        """
        if len(infos) == 0:
            return None
        infos = TensorDict(
            source={
                key: _selective_unsqueeze(
                    value.to(torch.float32), batch_size=self.batch_size
                )
                for key, value in infos.items()
            },
            batch_size=self.batch_size,
            device=self.device,
        )

        return infos

    def read_done(self, done):
        """Pass the done flags through ``_selective_unsqueeze``."""
        done = _selective_unsqueeze(done, batch_size=self.batch_size)
        return done

    def read_reward(self, rewards):
        """Pass an agent's rewards through ``_selective_unsqueeze``."""
        rewards = _selective_unsqueeze(rewards, batch_size=self.batch_size)
        return rewards

    def read_action(self, action, group: str = "agents"):
        """Split a group action into per-agent actions along dim 1.

        Discrete one-hot actions are first converted to categorical with the
        group's unbatched action spec.
        """
        if not self.continuous_actions and not self.categorical_actions:
            action = self.unbatched_action_spec[group, "action"].to_categorical(action)
        agent_actions = action.unbind(dim=1)
        return agent_actions

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(num_envs={self.num_envs}, n_agents={self.n_agents},"
            f" batch_size={self.batch_size}, device={self.device})"
        )

    def to(self, device: DEVICE_TYPING) -> EnvBase:
        """Move the wrapped vmas environment, then this wrapper, to ``device``."""
        self._env.to(device)
        return super().to(device)


class VmasEnv(VmasWrapper):
    """Vmas environment wrapper.

    Unlike :class:`VmasWrapper`, which wraps an already-built environment, this
    class builds the environment itself through ``vmas.make_env``.

    Args:
        scenario (str or ``vmas.simulator.scenario.BaseScenario``): the scenario
            forwarded to ``vmas.make_env``.
        num_envs (int): number of vectorized environments, forwarded to
            ``vmas.make_env``.
        continuous_actions (bool, optional): forwarded to ``vmas.make_env``.
            Defaults to ``True``.
        max_steps (int, optional): forwarded to ``vmas.make_env``.
            Defaults to ``None``.
        categorical_actions (bool, optional): if the environment actions are
            discrete, whether to transform them to categorical or one-hot.
            Defaults to ``True``.
        seed (int, optional): forwarded to ``vmas.make_env``.
            Defaults to ``None``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
            group agents in tensordicts for input/output.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        **kwargs: remaining keyword arguments; those that reach ``_build_env``
            (other than ``from_pixels`` and ``pixels_only``) are forwarded to
            ``vmas.make_env`` as scenario kwargs.

    Examples:
        >>> env = VmasEnv(
        ...     scenario="flocking",
        ...     num_envs=32,
        ...     continuous_actions=True,
        ...     max_steps=200,
        ...     device="cpu",
        ...     seed=None,
        ...     # Scenario kwargs
        ...     n_agents=5,
        ... )
        >>> print(env.rollout(10))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([32, 10, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                        info: TensorDict(
                            fields={
                                agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 10, 5]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                info: TensorDict(
                                    fields={
                                        agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                        agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([32, 10, 5]),
                                    device=cpu,
                                    is_shared=False),
                                observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([32, 10]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([32, 10]),
            device=cpu,
            is_shared=False)
    """

    def __init__(
        self,
        scenario: Union[str, "vmas.simulator.scenario.BaseScenario"],  # noqa
        num_envs: int,
        continuous_actions: bool = True,
        max_steps: Optional[int] = None,
        categorical_actions: bool = True,
        seed: Optional[int] = None,
        group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
        **kwargs,
    ):
        """Collect the construction arguments and defer to the parent constructor.

        Raises:
            ImportError: if the ``vmas`` package is not installed.
        """
        if not _has_vmas:
            raise ImportError(
                "vmas python package was not found. Please install this dependency. "
                f"More info: {self.git_url}."
            )
        # Every named argument travels through ``kwargs`` to the parent
        # constructor.
        kwargs["scenario"] = scenario
        kwargs["num_envs"] = num_envs
        kwargs["continuous_actions"] = continuous_actions
        kwargs["max_steps"] = max_steps
        kwargs["seed"] = seed
        kwargs["categorical_actions"] = categorical_actions
        kwargs["group_map"] = group_map
        super().__init__(**kwargs)

    def _check_kwargs(self, kwargs: Dict):
        """Check that ``kwargs`` names both a scenario and a number of envs.

        Raises:
            TypeError: if ``"scenario"`` or ``"num_envs"`` is missing.
        """
        if "scenario" not in kwargs:
            raise TypeError("Could not find environment key 'scenario' in kwargs.")
        if "num_envs" not in kwargs:
            raise TypeError("Could not find environment key 'num_envs' in kwargs.")

    def _build_env(
        self,
        scenario: Union[str, "vmas.simulator.scenario.BaseScenario"],  # noqa
        num_envs: int,
        continuous_actions: bool,
        max_steps: Optional[int],
        seed: Optional[int],
        **scenario_kwargs,
    ) -> "vmas.simulator.environment.environment.Environment":  # noqa
        """Create the vmas environment and hand it to the parent ``_build_env``.

        ``from_pixels`` and ``pixels_only`` are removed from
        ``scenario_kwargs`` and passed to the parent; everything left is
        forwarded to ``vmas.make_env``.
        """
        vmas = self.lib

        self.scenario_name = scenario
        from_pixels = scenario_kwargs.pop("from_pixels", False)
        pixels_only = scenario_kwargs.pop("pixels_only", False)

        return super()._build_env(
            env=vmas.make_env(
                scenario=scenario,
                num_envs=num_envs,
                device=self.device,
                continuous_actions=continuous_actions,
                max_steps=max_steps,
                seed=seed,
                wrapper=None,
                **scenario_kwargs,
            ),
            pixels_only=pixels_only,
            from_pixels=from_pixels,
        )

    def __repr__(self):
        return f"{super().__repr__()} (scenario={self.scenario_name})"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources