Shortcuts

Source code for torchrl.envs.libs.smacv2

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import re

from typing import Dict, Optional

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
    BoundedTensorSpec,
    CompositeSpec,
    DiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.envs.common import _EnvWrapper

from torchrl.envs.utils import _classproperty, ACTION_MASK_ERROR

_has_smacv2 = importlib.util.find_spec("smacv2") is not None


def _get_envs():
    """Return the names of all maps found in the SMACv2 map registry.

    Raises:
        ImportError: if the ``smacv2`` package is not installed.
    """
    if _has_smacv2:
        # Imported lazily so that this module can be loaded without smacv2.
        from smacv2.env.starcraft2.maps import smac_maps

        registry = smac_maps.get_smac_map_registry()
        return list(registry.keys())
    raise ImportError("SMAC-v2 is not installed in your virtual environment.")


class SMACv2Wrapper(_EnvWrapper):
    """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper.

    To install the environment follow the following `guide <https://github.com/oxwhirl/smacv2#getting-started>`__.

    Args:
        env (smacv2.env.StarCraft2Env): the already-built environment to wrap.
        categorical_actions (bool, optional): if ``True`` (default) the action
            spec is a :class:`DiscreteTensorSpec` of shape ``(n_agents,)``;
            otherwise it is a :class:`OneHotDiscreteTensorSpec` of shape
            ``(n_agents, n_actions)``.
        **kwargs: forwarded to the parent wrapper constructor.

    Examples:
        >>> from torchrl.envs.libs.smacv2 import SMACv2Wrapper
        >>> import smacv2
        >>> print(SMACv2Wrapper.available_envs)
        ['10gen_terran', '10gen_zerg', '10gen_protoss', '3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', '10m_vs_11m', '27m_vs_30m', 'MMM', 'MMM2', '2s3z', '3s5z', '3s5z_vs_3s6z', '3s_vs_3z', '3s_vs_4z', '3s_vs_5z', '1c3s5z', '2m_vs_1z', 'corridor', '6h_vs_8z', '2s_vs_1sc', 'so_many_baneling', 'bane_vs_bane', '2c_vs_64zg']
        >>> # You can use old SMAC maps
        >>> env = SMACv2Wrapper(smacv2.env.StarCraft2Env(map_name="MMM2"), categorical_actions=False)
        >>> print(env.rollout(5))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False),
                        action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5, 10]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                info: TensorDict(
                    fields={
                        battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False),
                        dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                        dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                        episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False),
                                observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([5, 10]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        info: TensorDict(
                            fields={
                                battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False),
                                dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                                dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                                episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)},
                            batch_size=torch.Size([5]),
                            device=cpu,
                            is_shared=False),
                        reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        >>> # Or the new features for procedural generation
        >>> distribution_config = {
        ...     "n_units": 5,
        ...     "n_enemies": 6,
        ...     "team_gen": {
        ...         "dist_type": "weighted_teams",
        ...         "unit_types": ["marine", "marauder", "medivac"],
        ...         "exception_unit_types": ["medivac"],
        ...         "weights": [0.5, 0.2, 0.3],
        ...         "observe": True,
        ...     },
        ...     "start_positions": {
        ...         "dist_type": "surrounded_and_reflect",
        ...         "p": 0.5,
        ...         "n_enemies": 5,
        ...         "map_x": 32,
        ...         "map_y": 32,
        ...     },
        ... }
        >>> env = SMACv2Wrapper(
        ...     smacv2.env.StarCraft2Env(
        ...         map_name="10gen_terran",
        ...         capability_config=distribution_config,
        ...     ),
        ...     categorical_actions=False,
        ... )
        >>> print(env.rollout(4))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False),
                        action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([4, 5]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                info: TensorDict(
                    fields={
                        battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                        dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                        dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                        episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([4]),
                    device=cpu,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                                observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([4, 5]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        info: TensorDict(
                            fields={
                                battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                                dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                                dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                                episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)},
                            batch_size=torch.Size([4]),
                            device=cpu,
                            is_shared=False),
                        reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([4]),
                    device=cpu,
                    is_shared=False),
                state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=cpu,
            is_shared=False)
    """

    git_url = "https://github.com/oxwhirl/smacv2"
    libname = "smacv2"

    @_classproperty
    def available_envs(cls):
        """List of registered SMACv2 map names (empty if smacv2 is missing)."""
        if not _has_smacv2:
            return []
        # ``_get_envs`` already returns a list.
        return _get_envs()

    def __init__(
        self,
        env: "smacv2.env.StarCraft2Env" = None,  # noqa: F821
        categorical_actions: bool = True,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env
        # Stored before the parent constructor runs; presumably the parent
        # triggers spec creation, which reads this flag in
        # ``_make_action_spec`` -- TODO confirm against ``_EnvWrapper``.
        self.categorical_actions = categorical_actions

        super().__init__(**kwargs)

    @property
    def lib(self):
        """The ``smacv2`` module, imported lazily."""
        import smacv2

        return smacv2

    def _check_kwargs(self, kwargs: Dict):
        """Validate that ``kwargs["env"]`` exists and is a ``StarCraft2Env``.

        Raises:
            TypeError: if the ``"env"`` key is missing or has the wrong type.
        """
        import smacv2

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, smacv2.env.StarCraft2Env):
            raise TypeError("env is not of type 'smacv2.env.StarCraft2Env'.")

    def _build_env(
        self,
        env: "smacv2.env.StarCraft2Env",  # noqa: F821
    ):
        """Return ``env`` unchanged after rejecting a non-empty batch size.

        Raises:
            RuntimeError: if ``self.batch_size`` is not empty.
        """
        if len(self.batch_size):
            raise RuntimeError(
                f"SMACv2 does not support custom batch_size {self.batch_size}."
            )

        return env

    def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None:  # noqa: F821
        """Populate the group map and the reward/done/action/observation specs."""
        # ``n_agents`` is not defined in this class; presumably it is resolved
        # on the wrapped env through the parent wrapper -- TODO confirm.
        self.group_map = {"agents": [str(i) for i in range(self.n_agents)]}
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size((1,)),
            device=self.device,
        )
        self.done_spec = DiscreteTensorSpec(
            n=2,
            shape=torch.Size((1,)),
            dtype=torch.bool,
            device=self.device,
        )
        self.action_spec = self._make_action_spec()
        self.observation_spec = self._make_observation_spec()

    def _init_env(self) -> None:
        """Reset the wrapped env and refresh the action mask."""
        self._env.reset()
        self._update_action_mask()

    def _make_action_spec(self) -> CompositeSpec:
        """Build the per-agent action spec, nested under the ``"agents"`` key.

        The leaf spec is categorical or one-hot depending on
        ``self.categorical_actions``.
        """
        if self.categorical_actions:
            action_spec = DiscreteTensorSpec(
                self.n_actions,
                shape=torch.Size((self.n_agents,)),
                device=self.device,
                dtype=torch.long,
            )
        else:
            action_spec = OneHotDiscreteTensorSpec(
                self.n_actions,
                shape=torch.Size((self.n_agents, self.n_actions)),
                device=self.device,
                dtype=torch.long,
            )
        spec = CompositeSpec(
            {
                "agents": CompositeSpec(
                    {"action": action_spec}, shape=torch.Size((self.n_agents,))
                )
            }
        )
        return spec

    def _make_observation_spec(self) -> CompositeSpec:
        """Build the observation spec.

        The result holds per-agent ``observation`` and ``action_mask`` entries
        under ``"agents"``, a global ``"state"`` entry and an ``"info"``
        sub-spec with ``battle_won``, ``episode_limit``, ``dead_allies`` and
        ``dead_enemies``.
        """
        obs_spec = BoundedTensorSpec(
            low=-1.0,
            high=1.0,
            shape=torch.Size([self.n_agents, self.get_obs_size()]),
            device=self.device,
            dtype=torch.float32,
        )
        info_spec = CompositeSpec(
            {
                "battle_won": DiscreteTensorSpec(
                    2, dtype=torch.bool, device=self.device
                ),
                "episode_limit": DiscreteTensorSpec(
                    2, dtype=torch.bool, device=self.device
                ),
                "dead_allies": BoundedTensorSpec(
                    low=0,
                    high=self.n_agents,
                    dtype=torch.long,
                    device=self.device,
                    shape=(),
                ),
                "dead_enemies": BoundedTensorSpec(
                    low=0,
                    high=self.n_enemies,
                    dtype=torch.long,
                    device=self.device,
                    shape=(),
                ),
            }
        )
        mask_spec = DiscreteTensorSpec(
            2,
            torch.Size([self.n_agents, self.n_actions]),
            device=self.device,
            dtype=torch.bool,
        )
        spec = CompositeSpec(
            {
                "agents": CompositeSpec(
                    {"observation": obs_spec, "action_mask": mask_spec},
                    shape=torch.Size((self.n_agents,)),
                ),
                "state": BoundedTensorSpec(
                    low=-1.0,
                    high=1.0,
                    shape=torch.Size((self.get_state_size(),)),
                    device=self.device,
                    dtype=torch.float32,
                ),
                "info": info_spec,
            }
        )
        return spec

    def _set_seed(self, seed: Optional[int]):
        """Reject any attempt to change the seed after construction.

        Raises:
            NotImplementedError: if ``seed`` is not ``None``.
        """
        if seed is not None:
            raise NotImplementedError(
                "Seed cannot be changed once environment was created."
            )

    def get_obs(self):
        """Return the wrapped env's observations as a float32 tensor."""
        obs = self._env.get_obs()
        return self._to_tensor(obs)

    def get_state(self):
        """Return the wrapped env's global state as a float32 tensor."""
        state = self._env.get_state()
        return self._to_tensor(state)

    def _to_tensor(self, value):
        """Convert ``value`` to a float32 tensor on ``self.device``."""
        return torch.tensor(value, device=self.device, dtype=torch.float32)

    def _reset(
        self, tensordict: Optional[TensorDictBase] = None, **kwargs
    ) -> TensorDictBase:
        """Reset the wrapped env and return the initial tensordict.

        The ``"info"`` entry is filled with zeros from the info spec.
        """
        obs, state = self._env.reset()

        # collect outputs
        obs = self._to_tensor(obs)
        state = self._to_tensor(state)
        info = self.observation_spec["info"].zero()

        mask = self._update_action_mask()

        # build results
        agents_td = TensorDict(
            {"observation": obs, "action_mask": mask}, batch_size=(self.n_agents,)
        )
        tensordict_out = TensorDict(
            source={"agents": agents_td, "state": state, "info": info},
            batch_size=(),
            device=self.device,
        )

        return tensordict_out

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Apply the agents' actions and return the resulting tensordict.

        Raises:
            RuntimeError: ``ACTION_MASK_ERROR`` when the wrapped env rejects an
                action with its "Agent ... cannot perform action ..."
                assertion.
            AssertionError: any other assertion raised by the wrapped env.
        """
        # perform actions
        action = tensordict.get(("agents", "action"))
        action_np = self.action_spec.to_numpy(action)

        # Actions are validated by the environment.
        try:
            reward, done, info = self._env.step(action_np)
        except AssertionError as err:
            # ``.+`` (rather than a single ``.``) so that multi-digit agent
            # indices, e.g. "Agent 12 cannot perform action 3", are recognised.
            if re.match(r"Agent .+ cannot perform action .+", str(err)):
                raise ACTION_MASK_ERROR from err
            raise

        # collect outputs
        obs = self.get_obs()
        state = self.get_state()
        info = self.observation_spec["info"].encode(info)
        # Fill in any info entry the env did not report with zeros.
        actual_keys = info.keys()
        for expected_key, spec in self.observation_spec["info"].items():
            if expected_key not in actual_keys:
                info[expected_key] = spec.zero()

        reward = torch.tensor(
            reward, device=self.device, dtype=torch.float32
        ).unsqueeze(-1)
        done = torch.tensor(done, device=self.device, dtype=torch.bool).unsqueeze(-1)

        mask = self._update_action_mask()

        # build results
        agents_td = TensorDict(
            {"observation": obs, "action_mask": mask}, batch_size=(self.n_agents,)
        )

        tensordict_out = TensorDict(
            source={
                "agents": agents_td,
                "state": state,
                "info": info,
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
            },
            batch_size=(),
            device=self.device,
        )

        return tensordict_out

    def _update_action_mask(self):
        """Read the available actions, update the action spec mask, return it."""
        mask = torch.tensor(
            self.get_avail_actions(), dtype=torch.bool, device=self.device
        )
        self.action_spec.update_mask(mask)
        return mask

    def close(self):
        """Close the wrapped environment."""
        # Closes StarCraft II
        self._env.close()

    def get_agent_type(self, agent_index: int) -> str:
        """Get the agent type string.

        Given the agent index, get its unit type name.

        Args:
            agent_index (int): the index of the agent to get the type of

        Raises:
            ValueError: if ``agent_index`` is outside ``[0, n_agents)``.
            AssertionError: if the unit type matches none of the known ids.
        """
        if agent_index < 0 or agent_index >= self.n_agents:
            raise ValueError(f"Agent index out of range, {self.n_agents} available")

        agent_info = self.agents[agent_index]
        unit_type = agent_info.unit_type
        # Checked in this fixed order, looking up each ``<name>_id`` attribute
        # only when reached, exactly like the original if/elif chain.
        for unit_name in (
            "marine",
            "marauder",
            "medivac",
            "hydralisk",
            "zergling",
            "baneling",
            "stalker",
            "colossus",
            "zealot",
        ):
            if unit_type == getattr(self, f"{unit_name}_id"):
                return unit_name
        raise AssertionError(f"Agent type {agent_info.unit_type} unidentified")

    # This patches the bug in https://github.com/oxwhirl/smacv2/issues/33
    def render(self, mode: str = "human"):
        """Render the environment.

        When the wrapped env is a ``StarCraftCapabilityEnvWrapper``, the call
        is made on its inner ``env`` attribute instead.
        """
        import smacv2

        if isinstance(self._env, smacv2.env.StarCraftCapabilityEnvWrapper):
            return self._env.env.render(mode=mode)
        else:
            return self._env.render(mode=mode)
class SMACv2Env(SMACv2Wrapper):
    """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper.

    To install the environment follow the following `guide <https://github.com/oxwhirl/smacv2#getting-started>`__.

    Args:
        map_name (str): name of the map to load.
        capability_config (dict, optional): when not ``None``, the environment
            is built with ``smacv2.env.StarCraftCapabilityEnvWrapper`` and this
            configuration; otherwise ``smacv2.env.StarCraft2Env`` is used.
        seed (int, optional): forwarded to the smacv2 environment constructor.
        categorical_actions (bool, optional): forwarded to
            :class:`SMACv2Wrapper`. Defaults to ``True``.
        **kwargs: forwarded to the smacv2 environment constructor.

    Examples:
        >>> from torchrl.envs.libs.smacv2 import SMACv2Env
        >>> print(SMACv2Env.available_envs)
        ['10gen_terran', '10gen_zerg', '10gen_protoss', '3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', '10m_vs_11m', '27m_vs_30m', 'MMM', 'MMM2', '2s3z', '3s5z', '3s5z_vs_3s6z', '3s_vs_3z', '3s_vs_4z', '3s_vs_5z', '1c3s5z', '2m_vs_1z', 'corridor', '6h_vs_8z', '2s_vs_1sc', 'so_many_baneling', 'bane_vs_bane', '2c_vs_64zg']
        >>> # You can use old SMAC maps
        >>> env = SMACv2Env(map_name="MMM2", categorical_actions=False)
        >>> print(env.rollout(5))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False),
                        action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5, 10]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                info: TensorDict(
                    fields={
                        battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False),
                        dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                        dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                        episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False),
                                observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([5, 10]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        info: TensorDict(
                            fields={
                                battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False),
                                dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                                dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
                                episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)},
                            batch_size=torch.Size([5]),
                            device=cpu,
                            is_shared=False),
                        reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        >>> # Or the new features for procedural generation
        >>> distribution_config = {
        ...     "n_units": 5,
        ...     "n_enemies": 6,
        ...     "team_gen": {
        ...         "dist_type": "weighted_teams",
        ...         "unit_types": ["marine", "marauder", "medivac"],
        ...         "exception_unit_types": ["medivac"],
        ...         "weights": [0.5, 0.2, 0.3],
        ...         "observe": True,
        ...     },
        ...     "start_positions": {
        ...         "dist_type": "surrounded_and_reflect",
        ...         "p": 0.5,
        ...         "n_enemies": 5,
        ...         "map_x": 32,
        ...         "map_y": 32,
        ...     },
        ... }
        >>> env = SMACv2Env(
        ...     map_name="10gen_terran",
        ...     capability_config=distribution_config,
        ...     categorical_actions=False,
        ... )
        >>> print(env.rollout(4))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False),
                        action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([4, 5]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                info: TensorDict(
                    fields={
                        battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                        dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                        dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                        episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([4]),
                    device=cpu,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                                observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([4, 5]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        info: TensorDict(
                            fields={
                                battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                                dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                                dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
                                episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)},
                            batch_size=torch.Size([4]),
                            device=cpu,
                            is_shared=False),
                        reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([4]),
                    device=cpu,
                    is_shared=False),
                state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=cpu,
            is_shared=False)
    """

    def __init__(
        self,
        map_name: str,
        capability_config: Optional[Dict] = None,
        seed: Optional[int] = None,
        categorical_actions: bool = True,
        **kwargs,
    ):
        if not _has_smacv2:
            raise ImportError(
                "smacv2 python package was not found. Please install this dependency. "
                f"More info: {self.git_url}."
            )
        # Everything is funnelled through ``kwargs`` so that the parent
        # constructor receives a single keyword mapping.
        kwargs["map_name"] = map_name
        kwargs["capability_config"] = capability_config
        kwargs["seed"] = seed
        kwargs["categorical_actions"] = categorical_actions

        super().__init__(**kwargs)

    def _check_kwargs(self, kwargs: Dict):
        """Validate that a map name was supplied.

        Raises:
            TypeError: if ``"map_name"`` is missing from ``kwargs``.
        """
        if "map_name" not in kwargs:
            raise TypeError("Expected 'map_name' to be part of kwargs")

    def _build_env(
        self,
        map_name: str,
        capability_config: Optional[Dict] = None,
        seed: Optional[int] = None,
        **kwargs,
    ) -> "smacv2.env.StarCraft2Env":  # noqa: F821
        """Create the smacv2 environment and hand it to the parent builder.

        A ``StarCraftCapabilityEnvWrapper`` is created when
        ``capability_config`` is given, a plain ``StarCraft2Env`` otherwise.
        """
        import smacv2.env

        if capability_config is not None:
            env = smacv2.env.StarCraftCapabilityEnvWrapper(
                capability_config=capability_config,
                map_name=map_name,
                seed=seed,
                **kwargs,
            )
        else:
            env = smacv2.env.StarCraft2Env(map_name=map_name, seed=seed, **kwargs)

        return super()._build_env(env)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources