Source code for torchrl.envs.libs.openspiel

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib.util
from typing import Dict, List

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
    Categorical,
    Composite,
    NonTensor,
    OneHot,
    Unbounded,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

_has_pyspiel = importlib.util.find_spec("pyspiel") is not None


def _get_envs():
    if not _has_pyspiel:
        raise ImportError(
            "open_spiel not found. Consider downloading and installing "
            f"open_spiel from {OpenSpielWrapper.git_url}."
        )

    import pyspiel

    return [game.short_name for game in pyspiel.registered_games()]


class OpenSpielWrapper(_EnvWrapper):
    """Google DeepMind OpenSpiel environment wrapper.

    GitHub: https://github.com/google-deepmind/open_spiel

    Documentation: https://openspiel.readthedocs.io/en/latest/index.html

    Args:
        env (pyspiel.State): the game to wrap.

    Keyword Args:
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``None``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`~.reset` is called.
            Defaults to ``False``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
            group agents in tensordicts for input/output. See
            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
            Defaults to
            :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
        categorical_actions (bool, optional): if ``True``, categorical specs
            will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
        return_state (bool, optional): if ``True``, "state" is included in the
            output of :meth:`~.reset` and :meth:`~.step`. The state can be given
            to :meth:`~.reset` to reset to that state, rather than resetting to
            the initial state. Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> import pyspiel
        >>> from torchrl.envs import OpenSpielWrapper
        >>> from tensordict import TensorDict
        >>> base_env = pyspiel.load_game('chess').new_initial_state()
        >>> env = OpenSpielWrapper(base_env, return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> print(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),
                            device=None,
                            is_shared=False),
                        current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 3009 , batch_size=torch.Size([]), device=None),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> print(env.available_envs)
        ['2048', 'add_noise', 'amazons', 'backgammon', ...]

    :meth:`~.reset` can restore a specific state, rather than the initial
    state, as long as ``return_state=True``.

        >>> import pyspiel
        >>> from torchrl.envs import OpenSpielWrapper
        >>> from tensordict import TensorDict
        >>> base_env = pyspiel.load_game('chess').new_initial_state()
        >>> env = OpenSpielWrapper(base_env, return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> td_restore = td["next"]
        >>> td = env.step(env.full_action_spec.rand())
        >>> # The current state is not equal to `td_restore`
        >>> (td["next"] == td_restore).all()
        False
        >>> td = env.reset(td_restore)
        >>> # After resetting, the current state is equal to `td_restore`
        >>> (td == td_restore).all()
        True
    """

    git_url = "https://github.com/google-deepmind/open_spiel"
    libname = "pyspiel"
    _lib = None

    @_classproperty
    def lib(cls):
        if cls._lib is not None:
            return cls._lib
        import pyspiel

        cls._lib = pyspiel
        return pyspiel

    @_classproperty
    def available_envs(cls):
        if not _has_pyspiel:
            return []
        return _get_envs()

    def __init__(
        self,
        env=None,
        *,
        group_map: MarlGroupMapType
        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        categorical_actions: bool = False,
        return_state: bool = False,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env
        self.group_map = group_map
        self.categorical_actions = categorical_actions
        self.return_state = return_state
        self._cached_game = None
        super().__init__(**kwargs)
        # `reset` allows resetting to any state, including a terminal state.
        self._allow_done_after_reset = True

    def _check_kwargs(self, kwargs: Dict):
        pyspiel = self.lib
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, pyspiel.State):
            raise TypeError("env is not of type 'pyspiel.State'.")

    def _build_env(self, env, requires_grad: bool = False, **kwargs):
        game = env.get_game()
        game_type = game.get_type()
        if game.max_chance_outcomes() != 0:
            raise NotImplementedError(
                f"The game '{game_type.short_name}' has chance nodes, which are not yet supported."
            )
        if game_type.dynamics == self.lib.GameType.Dynamics.MEAN_FIELD:
            # NOTE: It is unclear from the OpenSpiel documentation what
            # "mean field" means exactly, and there is no documentation for
            # the several games that use it.
            raise RuntimeError(
                f"Mean field games like '{game_type.name}' are not yet "
                "supported."
            )
        self.parallel = game_type.dynamics == self.lib.GameType.Dynamics.SIMULTANEOUS
        self.requires_grad = requires_grad
        return env

    def _init_env(self):
        self._update_action_mask()

    def _get_game(self):
        if self._cached_game is None:
            self._cached_game = self._env.get_game()
        return self._cached_game

    def _make_group_map(self, group_map, agent_names):
        if group_map is None:
            group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names)
        elif isinstance(group_map, MarlGroupMapType):
            group_map = group_map.get_group_map(agent_names)
        check_marl_grouping(group_map, agent_names)
        return group_map

    def _make_group_specs(
        self,
        env,
        group: str,
    ):
        observation_specs = []
        action_specs = []
        reward_specs = []
        game = env.get_game()
        for _ in self.group_map[group]:
            observation_spec = Composite()
            if self.has_observation:
                observation_spec["observation"] = Unbounded(
                    shape=(*game.observation_tensor_shape(),),
                    device=self.device,
                    domain="continuous",
                )
            if self.has_information_state:
                observation_spec["information_state"] = Unbounded(
                    shape=(*game.information_state_tensor_shape(),),
                    device=self.device,
                    domain="continuous",
                )
            observation_specs.append(observation_spec)
            action_spec_cls = Categorical if self.categorical_actions else OneHot
            action_specs.append(
                Composite(
                    action=action_spec_cls(
                        env.num_distinct_actions(),
                        dtype=torch.int64,
                        device=self.device,
                    )
                )
            )
            reward_specs.append(
                Composite(
                    reward=Unbounded(
                        shape=(1,),
                        device=self.device,
                        domain="continuous",
                    )
                )
            )

        group_observation_spec = torch.stack(
            observation_specs, dim=0
        )  # shape = (n_agents, n_obs_per_agent)
        group_action_spec = torch.stack(
            action_specs, dim=0
        )  # shape = (n_agents, n_actions_per_agent)
        group_reward_spec = torch.stack(reward_specs, dim=0)  # shape = (n_agents, 1)
        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
        )

    def _make_specs(self, env: "pyspiel.State") -> None:  # noqa: F821
        self.agent_names = [f"player_{index}" for index in range(env.num_players())]
        self.agent_names_to_indices_map = {
            agent_name: i for i, agent_name in enumerate(self.agent_names)
        }
        self.group_map = self._make_group_map(self.group_map, self.agent_names)
        self.done_spec = Categorical(
            n=2,
            shape=torch.Size((1,)),
            dtype=torch.bool,
            device=self.device,
        )

        game = env.get_game()
        game_type = game.get_type()
        # In OpenSpiel, a game's state may have an "observation" tensor, an
        # "information state" tensor, or both. If the OpenSpiel game does not
        # have one of these, its corresponding accessor functions raise an
        # error, so we must avoid calling them.
        self.has_observation = game_type.provides_observation_tensor
        self.has_information_state = game_type.provides_information_state_tensor

        observation_spec = {}
        action_spec = {}
        reward_spec = {}
        for group in self.group_map.keys():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
            ) = self._make_group_specs(
                env,
                group,
            )
            observation_spec[group] = group_observation_spec
            action_spec[group] = group_action_spec
            reward_spec[group] = group_reward_spec

        if self.return_state:
            observation_spec["state"] = NonTensor([])

        observation_spec["current_player"] = Unbounded(
            shape=(),
            dtype=torch.int,
            device=self.device,
            domain="discrete",
        )

        self.observation_spec = Composite(observation_spec)
        self.action_spec = Composite(action_spec)
        self.reward_spec = Composite(reward_spec)

    def _set_seed(self, seed):
        if seed is not None:
            raise NotImplementedError("This environment has no seed.")

    def current_player(self):
        return self._env.current_player()

    def _update_action_mask(self):
        if self._env.is_terminal():
            agents_acting = []
        else:
            agents_acting = (
                self.agent_names
                if self.parallel
                else [self.agent_names[self._env.current_player()]]
            )
        for group, agents in self.group_map.items():
            action_masks = []
            for agent in agents:
                agent_index = self.agent_names_to_indices_map[agent]
                if agent in agents_acting:
                    action_mask = torch.zeros(
                        self._env.num_distinct_actions(),
                        device=self.device,
                        dtype=torch.bool,
                    )
                    action_mask[self._env.legal_actions(agent_index)] = True
                else:
                    action_mask = torch.zeros(
                        self._env.num_distinct_actions(),
                        device=self.device,
                        dtype=torch.bool,
                    )
                    # In OpenSpiel parallel games, non-acting players are
                    # expected to take action 0.
                    # https://openspiel.readthedocs.io/en/latest/api_reference/state_apply_action.html
                    action_mask[0] = True
                action_masks.append(action_mask)
            self.full_action_spec[group, "action"].update_mask(
                torch.stack(action_masks, dim=0)
            )

    def _make_td_out(self, exclude_reward=False):
        done = torch.tensor(
            self._env.is_terminal(), device=self.device, dtype=torch.bool
        )
        current_player = torch.tensor(
            self.current_player(), device=self.device, dtype=torch.int
        )

        source = {
            "done": done,
            "terminated": done.clone(),
            "current_player": current_player,
        }

        if self.return_state:
            source["state"] = self._env.serialize()

        reward = self._env.returns()
        for group, agent_names in self.group_map.items():
            agent_tds = []
            for agent in agent_names:
                agent_index = self.agent_names_to_indices_map[agent]
                agent_source = {}
                if self.has_observation:
                    observation_shape = self._get_game().observation_tensor_shape()
                    agent_source["observation"] = self._to_tensor(
                        self._env.observation_tensor(agent_index)
                    ).reshape(observation_shape)
                if self.has_information_state:
                    information_state_shape = (
                        self._get_game().information_state_tensor_shape()
                    )
                    agent_source["information_state"] = self._to_tensor(
                        self._env.information_state_tensor(agent_index)
                    ).reshape(information_state_shape)
                if not exclude_reward:
                    agent_source["reward"] = self._to_tensor(reward[agent_index])

                agent_td = TensorDict(
                    source=agent_source,
                    batch_size=self.batch_size,
                    device=self.device,
                )
                agent_tds.append(agent_td)

            source[group] = torch.stack(agent_tds, dim=0)

        tensordict_out = TensorDict(
            source=source,
            batch_size=self.batch_size,
            device=self.device,
        )

        return tensordict_out

    def _get_action_from_tensor(self, tensor):
        if not self.categorical_actions:
            action = torch.argmax(tensor, dim=-1)
        else:
            action = tensor
        return action

    def _step_parallel(self, tensordict: TensorDictBase):
        actions = [0] * self._env.num_players()
        for group, agents in self.group_map.items():
            for index_in_group, agent in enumerate(agents):
                agent_index = self.agent_names_to_indices_map[agent]
                action_tensor = tensordict[group, "action"][index_in_group]
                action = self._get_action_from_tensor(action_tensor)
                actions[agent_index] = action
        self._env.apply_actions(actions)

    def _step_sequential(self, tensordict: TensorDictBase):
        agent_index = self._env.current_player()

        # If the game has ended, do nothing.
        if agent_index == self.lib.PlayerId.TERMINAL:
            return

        agent = self.agent_names[agent_index]
        agent_group = None
        agent_index_in_group = None
        for group, agents in self.group_map.items():
            if agent in agents:
                agent_group = group
                agent_index_in_group = agents.index(agent)
                break
        assert agent_group is not None

        action_tensor = tensordict[agent_group, "action"][agent_index_in_group]
        action = self._get_action_from_tensor(action_tensor)
        self._env.apply_action(action)

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.parallel:
            self._step_parallel(tensordict)
        else:
            self._step_sequential(tensordict)
        self._update_action_mask()
        return self._make_td_out()

    def _to_tensor(self, value):
        return torch.tensor(value, device=self.device, dtype=torch.float32)

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        game = self._get_game()

        if tensordict is not None and "state" in tensordict:
            new_env = game.deserialize_state(tensordict["state"])
        else:
            new_env = game.new_initial_state()

        self._env = new_env
        self._update_action_mask()
        return self._make_td_out(exclude_reward=True)
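
# Not part of the original module: a minimal usage sketch showing how the
# action masks maintained by ``_update_action_mask`` interact with
# ``full_action_spec.rand()`` in a sequential (turn-based) game. It assumes
# the "tic_tac_toe" game shipped with OpenSpiel; the helper name below is
# hypothetical.
def _example_random_tic_tac_toe_rollout():
    import pyspiel

    base_env = pyspiel.load_game("tic_tac_toe").new_initial_state()
    env = OpenSpielWrapper(base_env, categorical_actions=False)
    td = env.reset()
    while not td["done"].item():
        # ``rand()`` samples within the current mask, so the acting player
        # draws a legal move and non-acting players get the dummy action 0.
        td = env.step(env.full_action_spec.rand())["next"]
    return td
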

class OpenSpielEnv(OpenSpielWrapper):
    """Google DeepMind OpenSpiel environment wrapper built with the game string.

    GitHub: https://github.com/google-deepmind/open_spiel

    Documentation: https://openspiel.readthedocs.io/en/latest/index.html

    Args:
        game_string (str): the name of the game to wrap. Must be part of
            :attr:`~.available_envs`.

    Keyword Args:
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``None``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`~.reset` is called.
            Defaults to ``False``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
            group agents in tensordicts for input/output. See
            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
            Defaults to
            :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
        categorical_actions (bool, optional): if ``True``, categorical specs
            will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
        return_state (bool, optional): if ``True``, "state" is included in the
            output of :meth:`~.reset` and :meth:`~.step`. The state can be given
            to :meth:`~.reset` to reset to that state, rather than resetting to
            the initial state. Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> from torchrl.envs import OpenSpielEnv
        >>> from tensordict import TensorDict
        >>> env = OpenSpielEnv("chess", return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> print(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),
                            device=None,
                            is_shared=False),
                        current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 674 , batch_size=torch.Size([]), device=None),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> print(env.available_envs)
        ['2048', 'add_noise', 'amazons', 'backgammon', ...]

    :meth:`~.reset` can restore a specific state, rather than the initial
    state, as long as ``return_state=True``.

        >>> from torchrl.envs import OpenSpielEnv
        >>> from tensordict import TensorDict
        >>> env = OpenSpielEnv("chess", return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> td_restore = td["next"]
        >>> td = env.step(env.full_action_spec.rand())
        >>> # The current state is not equal to `td_restore`
        >>> (td["next"] == td_restore).all()
        False
        >>> td = env.reset(td_restore)
        >>> # After resetting, the current state is equal to `td_restore`
        >>> (td == td_restore).all()
        True
    """

    def __init__(
        self,
        game_string,
        *,
        group_map: MarlGroupMapType
        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        categorical_actions=False,
        return_state: bool = False,
        **kwargs,
    ):
        kwargs["game_string"] = game_string
        super().__init__(
            group_map=group_map,
            categorical_actions=categorical_actions,
            return_state=return_state,
            **kwargs,
        )

    def _build_env(
        self,
        game_string: str,
        **kwargs,
    ) -> "pyspiel.State":  # noqa: F821
        if not _has_pyspiel:
            raise ImportError(
                f"open_spiel not found, unable to create {game_string}. Consider "
                f"downloading and installing open_spiel from {self.git_url}"
            )
        requires_grad = kwargs.pop("requires_grad", False)
        parameters = kwargs.pop("parameters", None)
        if kwargs:
            raise ValueError("kwargs not supported.")
        if parameters:
            game = self.lib.load_game(game_string, parameters=parameters)
        else:
            game = self.lib.load_game(game_string)
        env = game.new_initial_state()
        return super()._build_env(
            env,
            requires_grad=requires_grad,
        )

    @property
    def game_string(self):
        return self._constructor_kwargs["game_string"]

    def _check_kwargs(self, kwargs: Dict):
        if "game_string" not in kwargs:
            raise TypeError("Expected 'game_string' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.game_string}, batch_size={self.batch_size}, device={self.device})"
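
# Not part of the original module: a minimal sketch of the optional
# ``parameters`` keyword handled by ``OpenSpielEnv._build_env``, which is
# forwarded to ``pyspiel.load_game``. It assumes the "go" game and its
# "board_size" parameter from the OpenSpiel registry; the helper name below
# is hypothetical.
def _example_openspiel_env_with_game_parameters():
    env = OpenSpielEnv("go", parameters={"board_size": 9})
    td = env.reset()
    # Both players are grouped under "agents" by default
    # (MarlGroupMapType.ALL_IN_ONE_GROUP).
    assert td["agents"].batch_size[0] == 2
    return env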
