Source code for torchrl.envs.libs.pettingzoo
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import copy
import importlib
import warnings
from typing import Dict, List, Tuple, Union
import packaging
import torch
from tensordict import TensorDictBase
from torchrl.data.tensor_specs import (
CompositeSpec,
DiscreteTensorSpec,
OneHotDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
_has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None
def _get_envs():
    """Return the names of the PettingZoo environments that can be loaded.

    The full PettingZoo registry is tried first; when importing it fails with
    ``ModuleNotFoundError`` a warning is emitted and the registry is rebuilt
    from the individual environment families instead.

    Raises:
        ImportError: if PettingZoo is not installed.
    """
    if _has_pettingzoo:
        try:
            from pettingzoo.utils.all_modules import (
                all_environments as env_registry,
            )
        except ModuleNotFoundError as import_error:
            warnings.warn(
                f"PettingZoo failed to load all modules with error message {import_error}, trying to load individual modules."
            )
            env_registry = _load_available_envs()
        return [*env_registry.keys()]
    raise ImportError("PettingZoo is not installed in your virtual environment.")
def _load_available_envs() -> Dict:
all_environments = {}
try:
from pettingzoo.mpe.all_modules import mpe_environments
all_environments.update(mpe_environments)
except ModuleNotFoundError as err:
warnings.warn(f"MPE environments failed to load with error message {err}.")
try:
from pettingzoo.sisl.all_modules import sisl_environments
all_environments.update(sisl_environments)
except ModuleNotFoundError as err:
warnings.warn(f"SISL environments failed to load with error message {err}.")
try:
from pettingzoo.classic.all_modules import classic_environments
all_environments.update(classic_environments)
except ModuleNotFoundError as err:
warnings.warn(f"Classic environments failed to load with error message {err}.")
try:
from pettingzoo.atari.all_modules import atari_environments
all_environments.update(atari_environments)
except ModuleNotFoundError as err:
warnings.warn(f"Atari environments failed to load with error message {err}.")
try:
from pettingzoo.butterfly.all_modules import butterfly_environments
all_environments.update(butterfly_environments)
except ModuleNotFoundError as err:
warnings.warn(
f"Butterfly environments failed to load with error message {err}."
)
return all_environments
[docs]class PettingZooWrapper(_EnvWrapper):
"""PettingZoo environment wrapper.
To install petting zoo follow the guide `here <https://github.com/Farama-Foundation/PettingZoo#installation>`__.
This class is a general torchrl wrapper for all PettingZoo environments.
It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``.
Let's see how in more detail:
In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step.
If the number of agents during the task varies, please set ``use_mask=True``.
``"mask"`` will be provided
as an output in each group and should be used to mask out dead agents.
The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).
In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).
If there are any unavailable actions for an agent,
the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
for each group to reflect the latest available actions. This should be passed to a masked distribution during
training.
As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same
neural network. You can split agents in different groups where they are heterogeneous or should be processed by
different neural networks. To group, you just need to pass a ``group_map`` at env construction time.
By default, agents in pettingzoo will be grouped by name.
For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:
>>> print(env.rand_action(env.reset()))
TensorDict(
fields={
agent: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]))},
adversary: TensorDict(
fields={
action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False),
action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([1]))},
batch_size=torch.Size([]))
>>> print(env.group_map)
{"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}
Otherwise, a group map can be specified or selected from some premade options.
See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should
have its own tensordict (similar to the pettingzoo parallel API).
Grouping is useful for leveraging vectorisation among agents whose data goes through the same
neural network.
Args:
env (``pettingzoo.utils.env.ParallelEnv`` or ``pettingzoo.utils.env.AECEnv``): the pettingzoo environment to wrap.
return_state (bool, optional): whether to return the global state from pettingzoo
(not available in all environments). Defaults to ``False``.
group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used
for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
them to categorical or one-hot. Defaults to ``True``.
seed (int, optional): the seed. Defaults to ``None``.
done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
parallel environments and ``all()`` for AEC ones.
Examples:
>>> # Parallel env
>>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper
>>> from pettingzoo.butterfly import pistonball_v6
>>> kwargs = {"n_pistons": 21, "continuous": True}
>>> env = PettingZooWrapper(
... env=pistonball_v6.parallel_env(**kwargs),
... return_state=True,
... group_map=None, # Use default for parallel (all pistons grouped together)
... )
>>> print(env.group_map)
... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']}
>>> env.rollout(10)
>>> # AEC env
>>> from pettingzoo.classic import tictactoe_v3
>>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper
>>> from torchrl.envs.utils import MarlGroupMapType
>>> env = PettingZooWrapper(
... env=tictactoe_v3.env(),
... use_mask=True, # Must use it since one player plays at a time
... group_map=None # # Use default for AEC (one group per player)
... )
>>> print(env.group_map)
... {'player_1': ['player_1'], 'player_2': ['player_2']}
>>> env.rollout(10)
"""
git_url = "https://github.com/Farama-Foundation/PettingZoo"
libname = "pettingzoo"
@_classproperty
def available_envs(cls):
    """List the PettingZoo environment names, or ``[]`` when PettingZoo is not installed."""
    return list(_get_envs()) if _has_pettingzoo else []
def __init__(
    self,
    env: Union[
        "pettingzoo.utils.env.ParallelEnv",  # noqa: F821
        "pettingzoo.utils.env.AECEnv",  # noqa: F821
    ] = None,
    return_state: bool = False,
    group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
    use_mask: bool = False,
    categorical_actions: bool = True,
    seed: int | None = None,
    done_on_any: bool | None = None,
    **kwargs,
):
    """Store the wrapper options and hand the environment to the base class.

    The options are stored on the instance before the base constructor is
    called. ``env`` (when given) is forwarded through ``kwargs`` under the
    ``"env"`` key, and ``allow_done_after_reset=True`` is always passed.
    """
    # Wrapper configuration, read later by the spec/step/reset machinery.
    self.seed = seed
    self.return_state = return_state
    self.group_map = group_map
    self.use_mask = use_mask
    self.done_on_any = done_on_any
    self.categorical_actions = categorical_actions

    if env is not None:
        kwargs.update(env=env)
    super().__init__(allow_done_after_reset=True, **kwargs)
def _get_default_group_map(self, agent_names: List[str]):
# This function performs the default grouping in pettingzoo
if not self.parallel:
# In AEC envs we will have one group per agent by default
group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names)
else:
# In parallel envs, by default
# Agents with names "str_int" will be grouped in group name "str"
group_map = {}
for agent_name in agent_names:
# See if the agent follows the convention "name_int"
follows_convention = True
agent_name_split = agent_name.split("_")
if len(agent_name_split) == 1:
follows_convention = False
try:
int(agent_name_split[-1])
except ValueError:
follows_convention = False
# If not, just put it in a single group
if not follows_convention:
group_map[agent_name] = [agent_name]
# Otherwise, group it with other agents that follow the same convention
else:
group_name = "_".join(agent_name_split[:-1])
if group_name in group_map:
group_map[group_name].append(agent_name)
else:
group_map[group_name] = [agent_name]
return group_map
@property
def lib(self):
    """The imported ``pettingzoo`` module (imported lazily on access)."""
    import pettingzoo as pettingzoo_module

    return pettingzoo_module
def _build_env(
    self,
    env: Union[
        "pettingzoo.utils.env.ParallelEnv",  # noqa: F821
        "pettingzoo.utils.env.AECEnv",  # noqa: F821
    ],
):
    """Validate the PettingZoo environment and record whether it is parallel.

    Sets ``self.parallel`` to ``True`` when ``env`` is a
    ``pettingzoo.utils.env.ParallelEnv`` and warns when the installed
    PettingZoo version differs from the tested one (1.24.3).

    Args:
        env: the PettingZoo environment to wrap.

    Returns:
        the same ``env`` object, unchanged.

    Raises:
        ValueError: if ``env`` is not parallel and ``use_mask`` is ``False``.
        RuntimeError: if the wrapper has a non-empty ``batch_size``.
    """
    # Import the submodule explicitly: ``import packaging`` alone does not
    # guarantee that ``packaging.version`` is loaded.
    import packaging.version
    import pettingzoo

    if packaging.version.parse(pettingzoo.__version__).base_version != "1.24.3":
        warnings.warn(
            "PettingZoo in TorchRL is tested using version == 1.24.3 , "
            "If you are using a different version and are experiencing compatibility issues, "
            "please raise an issue in the TorchRL github."
        )

    self.parallel = isinstance(env, pettingzoo.utils.env.ParallelEnv)
    if not self.parallel and not self.use_mask:
        raise ValueError("For AEC environments you need to set use_mask=True")
    if len(self.batch_size):
        raise RuntimeError(
            f"PettingZoo does not support custom batch_size {self.batch_size}."
        )

    return env
@set_gym_backend("gymnasium")
def _make_specs(
    self,
    env: Union[
        "pettingzoo.utils.env.ParallelEnv",  # noqa: F821
        "pettingzoo.utils.env.AECEnv",  # noqa: F821
    ],
) -> None:
    """Resolve the group map and build the action/observation/reward/done specs.

    When ``done_on_any`` was left to ``None`` it defaults to ``self.parallel``.
    The group map is resolved (default grouping, or a ``MarlGroupMapType``
    expanded over the possible agents) and checked, then one entry per group
    is added to each composite spec. The done spec additionally carries the
    root-level ``"done"``, ``"terminated"`` and ``"truncated"`` entries.
    """
    # Set default for done on any or all
    if self.done_on_any is None:
        self.done_on_any = self.parallel

    # Create and check group map
    if self.group_map is None:
        self.group_map = self._get_default_group_map(self.possible_agents)
    elif isinstance(self.group_map, MarlGroupMapType):
        self.group_map = self.group_map.get_group_map(self.possible_agents)
    check_marl_grouping(self.group_map, self.possible_agents)

    # No group is known to expose an action mask until its specs are built
    self.has_action_mask = dict.fromkeys(self.group_map.keys(), False)

    action_spec = CompositeSpec()
    observation_spec = CompositeSpec()
    reward_spec = CompositeSpec()
    # Root-level (environment-wide) done flags
    done_spec = CompositeSpec(
        {
            done_key: DiscreteTensorSpec(
                n=2,
                shape=torch.Size((1,)),
                dtype=torch.bool,
                device=self.device,
            )
            for done_key in ("done", "terminated", "truncated")
        },
    )

    for group, agents in self.group_map.items():
        group_specs = self._make_group_specs(group_name=group, agent_names=agents)
        (
            observation_spec[group],
            action_spec[group],
            reward_spec[group],
            done_spec[group],
        ) = (group_specs[1 - 1], group_specs[1], group_specs[2], group_specs[3])
        # NOTE: assignment order differs from the tuple order on purpose;
        # see below where each spec lands under its own composite.

    self.action_spec = action_spec
    self.observation_spec = observation_spec
    self.reward_spec = reward_spec
    self.done_spec = done_spec
def _make_group_specs(self, group_name: str, agent_names: List[str]):
    """Build the observation, action, reward and done specs of one group.

    Per-agent action and observation specs are converted from the PettingZoo
    spaces and stacked along a leading agent dimension. If the observation
    contains an ``"action_mask"`` entry it is removed from the observation
    and re-exposed as a boolean ``"action_mask"`` spec at group level, and
    ``self.has_action_mask[group_name]`` is set. A ``"mask"`` spec is added
    when ``self.use_mask`` is set.

    Returns:
        the tuple ``(group_observation_spec, group_action_spec,
        group_reward_spec, group_done_spec)``.
    """
    n_agents = len(agent_names)
    per_agent_actions = []
    per_agent_observations = []
    for agent in agent_names:
        agent_action = _gym_to_torchrl_spec_transform(
            self.action_space(agent),
            remap_state_to_observation=False,
            categorical_action_encoding=self.categorical_actions,
            device=self.device,
        )
        per_agent_actions.append(CompositeSpec({"action": agent_action}))
        agent_observation = _gym_to_torchrl_spec_transform(
            self.observation_space(agent),
            remap_state_to_observation=False,
            device=self.device,
        )
        per_agent_observations.append(
            CompositeSpec({"observation": agent_observation})
        )
    group_action_spec = torch.stack(per_agent_actions, dim=0)
    group_observation_spec = torch.stack(per_agent_observations, dim=0)

    # The action mask may live inside the observation; move it to a
    # standard group-level location so downstream code finds it in one place.
    inner_observation_spec = group_observation_spec["observation"]
    observation_has_mask = (
        isinstance(inner_observation_spec, CompositeSpec)
        and "action_mask" in inner_observation_spec.keys()
    )
    if observation_has_mask:
        self.has_action_mask[group_name] = True
        del inner_observation_spec["action_mask"]
        if self.categorical_actions:
            # Categorical actions carry no trailing "n" dimension: add it
            mask_shape = (
                *group_action_spec["action"].shape,
                group_action_spec["action"].space.n,
            )
        else:
            mask_shape = group_action_spec["action"].shape
        group_observation_spec["action_mask"] = DiscreteTensorSpec(
            n=2,
            shape=mask_shape,
            dtype=torch.bool,
            device=self.device,
        )

    if self.use_mask:
        group_observation_spec["mask"] = DiscreteTensorSpec(
            n=2,
            shape=torch.Size((n_agents,)),
            dtype=torch.bool,
            device=self.device,
        )

    group_reward_spec = CompositeSpec(
        {
            "reward": UnboundedContinuousTensorSpec(
                shape=torch.Size((n_agents, 1)),
                device=self.device,
                dtype=torch.float32,
            )
        },
        shape=torch.Size((n_agents,)),
    )
    group_done_spec = CompositeSpec(
        {
            done_key: DiscreteTensorSpec(
                n=2,
                shape=torch.Size((n_agents, 1)),
                dtype=torch.bool,
                device=self.device,
            )
            for done_key in ("done", "terminated", "truncated")
        },
        shape=torch.Size((n_agents,)),
    )

    return (
        group_observation_spec,
        group_action_spec,
        group_reward_spec,
        group_done_spec,
    )
def _check_kwargs(self, kwargs: Dict):
    """Check that ``kwargs`` holds a PettingZoo environment under ``"env"``.

    Raises:
        TypeError: if the ``"env"`` key is missing, or if its value is neither
            a ``ParallelEnv`` nor an ``AECEnv``.
    """
    import pettingzoo

    if "env" not in kwargs:
        raise TypeError("Could not find environment key 'env' in kwargs.")
    supported_types = (
        pettingzoo.utils.env.ParallelEnv,
        pettingzoo.utils.env.AECEnv,
    )
    if isinstance(kwargs["env"], supported_types):
        return
    raise TypeError("env is not of type expected.")
def _init_env(self):
    """Complete the specs with info/state entries and cache zeroed outputs.

    The environment is reset once (with ``self.seed``) to inspect the info
    dictionary of each agent. Info entries are added to the group observation
    spec; an ``"action_mask"`` found in the info is moved to the group-level
    ``"action_mask"`` spec instead. When ``return_state`` is set, a
    ``"state"`` spec is added. Finally, zero-filled reset and step outputs
    are cached for reuse.
    """
    # Add info
    reset_fn = self._reset_parallel if self.parallel else self._reset_aec
    _, info_dict = reset_fn(seed=self.seed)

    for group, agents in self.group_map.items():
        per_agent_info_specs = [
            CompositeSpec(
                {
                    "info": CompositeSpec(
                        {
                            key: UnboundedContinuousTensorSpec(
                                shape=torch.as_tensor(value).shape,
                                device=self.device,
                            )
                            for key, value in info_dict[agent].items()
                        }
                    )
                },
                device=self.device,
            )
            for agent in agents
        ]
        info_specs = torch.stack(per_agent_info_specs, dim=0)

        if ("info", "action_mask") in info_specs.keys(True, True):
            if not self.has_action_mask[group]:
                # First time a mask is seen for this group: expose its spec
                self.has_action_mask[group] = True
                group_action_spec = self.input_spec[
                    "full_action_spec", group, "action"
                ]
                if self.categorical_actions:
                    mask_shape = (
                        *group_action_spec.shape,
                        group_action_spec.space.n,
                    )
                else:
                    mask_shape = group_action_spec.shape
                self.observation_spec[group]["action_mask"] = DiscreteTensorSpec(
                    n=2,
                    shape=mask_shape,
                    dtype=torch.bool,
                    device=self.device,
                )
            # The mask is never kept among the info entries
            group_inner_info_spec = info_specs["info"]
            del group_inner_info_spec["action_mask"]

        if len(info_specs["info"].keys()):
            self.observation_spec[group].update(info_specs)

    if self.return_state:
        try:
            state_spec = _gym_to_torchrl_spec_transform(
                self.state_space,
                remap_state_to_observation=False,
                device=self.device,
            )
        except AttributeError:
            # No declared state space: infer shape and dtype from a sample
            state_example = torch.as_tensor(self.state(), device=self.device)
            state_spec = UnboundedContinuousTensorSpec(
                shape=state_example.shape,
                dtype=state_example.dtype,
                device=self.device,
            )
        self.observation_spec["state"] = state_spec

    # Caching
    reset_zero = self.observation_spec.zero()
    reset_zero.update(self.output_spec["full_done_spec"].zero())
    self.cached_reset_output_zero = reset_zero

    step_zero = self.observation_spec.zero()
    step_zero.update(self.output_spec["full_reward_spec"].zero())
    step_zero.update(self.output_spec["full_done_spec"].zero())
    self.cached_step_output_zero = step_zero
def _set_seed(self, seed: int):
self.seed = seed
self.reset(seed=self.seed)
def _reset(
    self, tensordict: TensorDictBase | None = None, **kwargs
) -> TensorDictBase:
    """Reset the wrapped environment and return the first observations.

    A partial reset is not supported: if ``tensordict`` carries a ``"_reset"``
    entry that is not entirely ``True`` a ``RuntimeError`` is raised.
    The output starts from the cached zeroed reset tensordict, in which the
    agent mask, the action mask, the observations and (when present) the
    info entries are filled in.
    """
    reset_signal = None if tensordict is None else tensordict.get("_reset", None)
    if reset_signal is not None and not reset_signal.all():
        raise RuntimeError(
            f"An attempt to call {type(self)}._reset was made when no "
            f"reset signal could be found. Expected '_reset' entry to "
            f"be `tensor(True)` or `None` but got `{reset_signal}`."
        )

    reset_fn = self._reset_parallel if self.parallel else self._reset_aec
    observation_dict, info_dict = reset_fn(**kwargs)

    # We start with zeroed data and fill in the data for alive agents
    tensordict_out = self.cached_reset_output_zero.clone()
    # Update the "mask" for non-acting agents
    self._update_agent_mask(tensordict_out)
    # Update the "action_mask" for non-available actions
    observation_dict, info_dict = self._update_action_mask(
        tensordict_out, observation_dict, info_dict
    )

    # Now we get the data (obs and info)
    for group, agent_names in self.group_map.items():
        group_observation = tensordict_out.get((group, "observation"))
        group_info = tensordict_out.get((group, "info"), None)

        for index, agent in enumerate(agent_names):
            agent_observation_spec = self.observation_spec[group, "observation"][
                index
            ]
            group_observation[index] = agent_observation_spec.encode(
                observation_dict[agent]
            )

            if group_info is None:
                continue
            for info_key, info_value in info_dict[agent].items():
                group_info.get(info_key)[index] = torch.as_tensor(
                    info_value, device=self.device
                )

    return tensordict_out
def _reset_aec(self, **kwargs) -> Tuple[Dict, Dict]:
self._env.reset(**kwargs)
observation_dict = {
agent: self._env.observe(agent) for agent in self.possible_agents
}
info_dict = self._env.infos
return observation_dict, info_dict
def _reset_parallel(self, **kwargs) -> Tuple[Dict, Dict]:
return self._env.reset(**kwargs)
def _step(
    self,
    tensordict: TensorDictBase,
) -> TensorDictBase:
    """Step the wrapped environment and pack its outputs into a tensordict.

    The output starts from the cached zeroed step tensordict. Agents present
    in the returned observations have their observation, reward, done flags
    and (when present) info written. Agents that are absent are treated as
    dead: this requires ``use_mask=True`` and their done flags are taken from
    the termination/truncation dictionaries, defaulting to ``True``. The
    root-level done flags are the aggregation of the per-agent ones.

    Raises:
        ValueError: if an agent is missing from the observations while
            ``use_mask`` is ``False``, or while still listed in
            ``self.agents``.
    """
    step_fn = self._step_parallel if self.parallel else self._step_aec
    (
        observation_dict,
        rewards_dict,
        terminations_dict,
        truncations_dict,
        info_dict,
    ) = step_fn(tensordict)

    # We start with zeroed data and fill in the data for alive agents
    tensordict_out = self.cached_step_output_zero.clone()
    # Update the "mask" for non-acting agents
    self._update_agent_mask(tensordict_out)
    # Update the "action_mask" for non-available actions
    observation_dict, info_dict = self._update_action_mask(
        tensordict_out, observation_dict, info_dict
    )

    def as_flag(value):
        # Boolean tensor on the wrapper's device
        return torch.tensor(value, device=self.device, dtype=torch.bool)

    # Now we get the data
    for group, agent_names in self.group_map.items():
        group_observation = tensordict_out.get((group, "observation"))
        group_reward = tensordict_out.get((group, "reward"))
        group_done = tensordict_out.get((group, "done"))
        group_terminated = tensordict_out.get((group, "terminated"))
        group_truncated = tensordict_out.get((group, "truncated"))
        group_info = tensordict_out.get((group, "info"), None)

        for index, agent in enumerate(agent_names):
            if agent in observation_dict:  # Live agents
                group_observation[index] = self.observation_spec[
                    group, "observation"
                ][index].encode(observation_dict[agent])
                group_reward[index] = torch.tensor(
                    rewards_dict[agent],
                    device=self.device,
                    dtype=torch.float32,
                )
                group_done[index] = as_flag(
                    terminations_dict[agent] or truncations_dict[agent]
                )
                group_truncated[index] = as_flag(truncations_dict[agent])
                group_terminated[index] = as_flag(terminations_dict[agent])

                if group_info is not None:
                    for info_key, info_value in info_dict[agent].items():
                        group_info.get(info_key)[index] = torch.tensor(
                            info_value, device=self.device
                        )
                continue

            if not self.use_mask:
                # Dead agent, if we are not masking it out, this is not allowed
                raise ValueError(
                    "Dead agents found in the environment,"
                    " you need to set use_mask=True to allow this."
                )
            if agent in self.agents:
                raise ValueError(
                    f"Dead agent {agent} not found in step observation but still available in {self.agents}"
                )

            # Dead agent: flags missing from the env outputs count as True
            terminated = terminations_dict[agent] if agent in terminations_dict else True
            truncated = truncations_dict[agent] if agent in truncations_dict else True
            group_done[index] = terminated or truncated
            group_terminated[index] = terminated
            group_truncated[index] = truncated

    # set done values
    done, terminated, truncated = self._aggregate_done(
        tensordict_out, use_any=self.done_on_any
    )
    tensordict_out.set("done", done)
    tensordict_out.set("terminated", terminated)
    tensordict_out.set("truncated", truncated)
    return tensordict_out
def _aggregate_done(self, tensordict_out, use_any):
done = False if use_any else True
truncated = False if use_any else True
terminated = False if use_any else True
for key in self.done_keys:
if isinstance(key, tuple): # Only look at group keys
if use_any:
if key[-1] == "done":
done = done | tensordict_out.get(key).any()
if key[-1] == "terminated":
terminated = terminated | tensordict_out.get(key).any()
if key[-1] == "truncated":
truncated = truncated | tensordict_out.get(key).any()
if done and terminated and truncated:
# no need to proceed further, all values are flipped
break
else:
if key[-1] == "done":
done = done & tensordict_out.get(key).all()
if key[-1] == "terminated":
terminated = terminated & tensordict_out.get(key).all()
if key[-1] == "truncated":
truncated = truncated & tensordict_out.get(key).all()
if not done and not terminated and not truncated:
# no need to proceed further, all values are flipped
break
return (
torch.tensor([done], device=self.device),
torch.tensor([terminated], device=self.device),
torch.tensor([truncated], device=self.device),
)
def _step_parallel(
self,
tensordict: TensorDictBase,
) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
action_dict = {}
for group, agents in self.group_map.items():
group_action = tensordict.get((group, "action"))
group_action_np = self.input_spec[
"full_action_spec", group, "action"
].to_numpy(group_action)
for index, agent in enumerate(agents):
action_dict[agent] = group_action_np[index]
return self._env.step(action_dict)
def _step_aec(
self,
tensordict: TensorDictBase,
) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
for group, agents in self.group_map.items():
if self.agent_selection in agents:
agent_index = agents.index(self._env.agent_selection)
group_action = tensordict.get((group, "action"))
group_action_np = self.input_spec[
"full_action_spec", group, "action"
].to_numpy(group_action)
action = group_action_np[agent_index]
break
self._env.step(action)
terminations_dict = self._env.terminations
truncations_dict = self._env.truncations
info_dict = self._env.infos
rewards_dict = self._env.rewards
observation_dict = {
agent: self._env.observe(agent) for agent in self.possible_agents
}
return (
observation_dict,
rewards_dict,
terminations_dict,
truncations_dict,
info_dict,
)
def _update_action_mask(self, td, observation_dict, info_dict):
    """Move per-agent action masks into ``td`` and strip them from the env data.

    For every group flagged in ``self.has_action_mask`` the group
    ``"action_mask"`` entry of ``td`` is first set to all ``True``; then the
    mask of each acting agent (found either in its observation or in its
    info dictionary) is written at the agent's index. The ``"action_mask"``
    key is removed from the (copied) observation/info dictionaries, and the
    mask of discrete action specs is updated.

    Agents that are missing from ``observation_dict`` or ``info_dict`` (as
    happens for dead agents) are skipped and keep an all-``True`` mask.

    Returns:
        deep copies of ``observation_dict`` and ``info_dict`` without the
        ``"action_mask"`` entries.
    """
    # Since we remove the action_mask keys we need to copy the data
    observation_dict = copy.deepcopy(observation_dict)
    info_dict = copy.deepcopy(info_dict)
    # In AEC only one agent acts, in parallel env self.agents contains the agents alive
    agents_acting = self.agents if self.parallel else [self.agent_selection]

    for group, agents in self.group_map.items():
        if not self.has_action_mask[group]:
            continue
        group_mask = td.get((group, "action_mask"))
        group_mask += True
        for index, agent in enumerate(agents):
            # ``.get``: dead agents may be absent from the env outputs
            agent_obs = observation_dict.get(agent)
            agent_info = info_dict.get(agent)
            if isinstance(agent_obs, dict) and "action_mask" in agent_obs:
                if agent in agents_acting:
                    group_mask[index] = torch.tensor(
                        agent_obs["action_mask"],
                        device=self.device,
                        dtype=torch.bool,
                    )
                del agent_obs["action_mask"]
            elif isinstance(agent_info, dict) and "action_mask" in agent_info:
                if agent in agents_acting:
                    group_mask[index] = torch.tensor(
                        agent_info["action_mask"],
                        device=self.device,
                        dtype=torch.bool,
                    )
                del agent_info["action_mask"]

        group_action_spec = self.input_spec["full_action_spec", group, "action"]
        if isinstance(
            group_action_spec, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
        ):
            # We update the mask for available actions
            group_action_spec.update_mask(group_mask.clone())

    return observation_dict, info_dict
def _update_agent_mask(self, td):
if self.use_mask:
# In AEC only one agent acts, in parallel env self.agents contains the agents alive
agents_acting = self.agents if self.parallel else [self.agent_selection]
for group, agents in self.group_map.items():
group_mask = td.get((group, "mask"))
group_mask += True
# We now add dead agents to the mask
for index, agent in enumerate(agents):
if agent not in agents_acting:
group_mask[index] = False
def close(self) -> None:
    """Close the wrapped PettingZoo environment."""
    wrapped_env = self._env
    wrapped_env.close()
[docs]class PettingZooEnv(PettingZooWrapper):
"""PettingZoo Environment.
To install petting zoo follow the guide `here <https://github.com/Farama-Foundation/PettingZoo#installation>`__.
This class is a general torchrl wrapper for all PettingZoo environments.
It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``.
Let's see how in more detail:
For wrapping ``pettingzoo.ParallelEnv`` provide the name of your petting zoo task (in the ``task`` argument)
and specify ``parallel=True``. This will construct the ``pettingzoo.ParallelEnv`` version of that task
(if it is supported in pettingzoo) and wrap it for torchrl.
In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step.
If the number of agents during the task varies, please set ``use_mask=True``.
``"mask"`` will be provided
as an output in each group and should be used to mask out dead agents.
The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).
For wrapping ``pettingzoo.AECEnv`` provide the name of your petting zoo task (in the ``task`` argument)
and specify ``parallel=False``. This will construct the ``pettingzoo.AECEnv`` version of that task
and wrap it for torchrl.
In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).
If there are any unavailable actions for an agent,
the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
for each group to reflect the latest available actions. This should be passed to a masked distribution during
training.
As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same
neural network. You can split agents in different groups where they are heterogeneous or should be processed by
different neural networks. To group, you just need to pass a ``group_map`` at env construction time.
By default, agents in pettingzoo will be grouped by name.
For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:
>>> print(env.rand_action(env.reset()))
TensorDict(
fields={
agent: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]))},
adversary: TensorDict(
fields={
action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False),
action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([1]))},
batch_size=torch.Size([]))
>>> print(env.group_map)
{"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}
Otherwise, a group map can be specified or selected from some premade options.
See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should
have its own tensordict (similar to the pettingzoo parallel API).
Grouping is useful for leveraging vectorisation among agents whose data goes through the same
neural network.
Args:
task (str): the name of the pettingzoo task to create in the "<env>/<task>" format (for example, "sisl/multiwalker_v9")
or "<task>" format (for example, "multiwalker_v9").
parallel (bool): if to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``.
return_state (bool, optional): whether to return the global state from pettingzoo
(not available in all environments). Defaults to ``False``.
group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used
for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
them to categorical or one-hot. Defaults to ``True``.
seed (int, optional): the seed. Defaults to ``None``.
done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
parallel environments and ``all()`` for AEC ones.
Examples:
>>> # Parallel env
>>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
>>> kwargs = {"n_pistons": 21, "continuous": True}
>>> env = PettingZooEnv(
... task="pistonball_v6",
... parallel=True,
... return_state=True,
... group_map=None, # Use default (all pistons grouped together)
... **kwargs,
... )
>>> print(env.group_map)
... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']}
>>> env.rollout(10)
>>> # AEC env
>>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
>>> from torchrl.envs.utils import MarlGroupMapType
>>> env = PettingZooEnv(
... task="tictactoe_v3",
... parallel=False,
... use_mask=True, # Must use it since one player plays at a time
... group_map=None # # Use default for AEC (one group per player)
... )
>>> print(env.group_map)
... {'player_1': ['player_1'], 'player_2': ['player_2']}
>>> env.rollout(10)
"""
def __init__(
    self,
    task: str,
    parallel: bool,
    return_state: bool = False,
    group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
    use_mask: bool = False,
    categorical_actions: bool = True,
    seed: int | None = None,
    done_on_any: bool | None = None,
    **kwargs,
):
    """Forward the task description and wrapper options to the parent class.

    All named arguments are merged into ``kwargs`` before calling the parent
    constructor.

    Raises:
        ImportError: if the ``pettingzoo`` package cannot be found.
    """
    if not _has_pettingzoo:
        raise ImportError(
            f"pettingzoo python package was not found. Please install this dependency. "
            f"More info: {self.git_url}."
        )
    kwargs.update(
        task=task,
        parallel=parallel,
        return_state=return_state,
        group_map=group_map,
        use_mask=use_mask,
        categorical_actions=categorical_actions,
        seed=seed,
        done_on_any=done_on_any,
    )
    super().__init__(**kwargs)
def _check_kwargs(self, kwargs: Dict):
if "task" not in kwargs:
raise TypeError("Could not find environment key 'task' in kwargs.")
if "parallel" not in kwargs:
raise TypeError("Could not find environment key 'parallel' in kwargs.")
def _build_env(
    self,
    task: str,
    parallel: bool,
    **kwargs,
) -> Union[
    "pettingzoo.utils.env.ParallelEnv",  # noqa: F821
    "pettingzoo.utils.env.AECEnv",  # noqa: F821
]:
    """Construct the PettingZoo environment for ``task`` and validate it.

    ``task`` is looked up among the registry keys first, then against the
    last component of each registered module's ``__name__``. The parallel or
    AEC constructor of the module is called with ``kwargs`` and the result is
    passed to the parent ``_build_env``.

    Raises:
        RuntimeError: if ``task`` matches no available environment.
    """
    self.task_name = task

    try:
        from pettingzoo.utils.all_modules import all_environments
    except ModuleNotFoundError as err:
        warnings.warn(
            f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
        )
        all_environments = _load_available_envs()

    if task in all_environments:
        task_module = all_environments[task]
    else:
        # Try looking at the literal translation of values
        task_module = next(
            (
                candidate
                for candidate in all_environments.values()
                if candidate.__name__.split(".")[-1] == task
            ),
            None,
        )
        if task_module is None:
            raise RuntimeError(
                f"Specified task not in available environments {all_environments}"
            )

    make_env = task_module.parallel_env if parallel else task_module.env
    petting_zoo_env = make_env(**kwargs)

    return super()._build_env(env=petting_zoo_env)