# Source code for ``torchrl.envs.custom.chess`` (extracted from the TorchRL documentation pages).

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import io
import pathlib

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data.tensor_specs import (
    Binary,
    Bounded,
    Categorical,
    Composite,
    NonTensor,
    Unbounded,
)
from torchrl.envs import EnvBase
from torchrl.envs.common import _EnvPostInit
from torchrl.envs.utils import _classproperty


class _ChessMeta(_EnvPostInit):
    """Metaclass for :class:`ChessEnv`.

    After the environment instance is constructed, optionally appends a
    ``Hash`` transform (when ``include_hash=True``) hashing the enabled string
    observations (san/fen/pgn), and an ``ActionMask`` transform (when
    ``mask_actions=True``, the default) so sampled actions are restricted to
    legal moves.
    """

    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        # ChessEnv.__init__ is keyword-only, so the options are recoverable
        # from kwargs (absent keys mean the default was used).
        include_hash = kwargs.get("include_hash")
        include_hash_inv = kwargs.get("include_hash_inv")
        if include_hash:
            from torchrl.envs import Hash

            in_keys = []
            out_keys = []
            # Inverse key lists are only materialized when the inverse hash
            # mapping was requested; ``None`` disables it in the transform.
            in_keys_inv = [] if include_hash_inv else None
            out_keys_inv = [] if include_hash_inv else None

            def maybe_add_keys(condition, in_key, out_key):
                # Register the (in, out) hash key pair when the corresponding
                # observation is enabled on the instance.
                if condition:
                    in_keys.append(in_key)
                    out_keys.append(out_key)
                    if include_hash_inv:
                        in_keys_inv.append(in_key)
                        out_keys_inv.append(out_key)

            maybe_add_keys(instance.include_san, "san", "san_hash")
            maybe_add_keys(instance.include_fen, "fen", "fen_hash")
            maybe_add_keys(instance.include_pgn, "pgn", "pgn_hash")

            instance = instance.append_transform(
                Hash(in_keys, out_keys, in_keys_inv, out_keys_inv)
            )
        elif include_hash_inv:
            # Fixed: the original message was missing a space between "if"
            # and "'include_hash=True'".
            raise ValueError(
                "'include_hash_inv=True' can only be set if "
                f"'include_hash=True', but got 'include_hash={include_hash}'."
            )
        if kwargs.get("mask_actions", True):
            from torchrl.envs import ActionMask

            instance = instance.append_transform(ActionMask())
        return instance


class ChessEnv(EnvBase, metaclass=_ChessMeta):
    r"""A chess environment that follows the TorchRL API.

    This environment simulates a chess game using the `chess` library. It
    supports various state representations and can be configured to include
    different types of observations such as SAN, FEN, PGN, and legal moves.

    Requires: the `chess` library. More info
    `here <https://python-chess.readthedocs.io/en/latest/>`__.

    Args:
        stateful (bool): Whether to keep track of the internal state of the
            board. If False, the state will be stored in the observation and
            passed back to the environment on each call. Default: ``True``.
        include_san (bool): Whether to include SAN (Standard Algebraic
            Notation) in the observations. Default: ``False``.

            .. note:: The `"san"` entry corresponding to `rollout["action"]`
                will be found in `rollout["next", "san"]`, whereas the value at
                the root `rollout["san"]` will correspond to the value of the
                san preceding the same index action.

        include_fen (bool): Whether to include FEN (Forsyth-Edwards Notation)
            in the observations. Default: ``False``.
        include_pgn (bool): Whether to include PGN (Portable Game Notation)
            in the observations. Default: ``False``.
        include_legal_moves (bool): Whether to include legal moves in the
            observations. Default: ``False``.
        include_hash (bool): Whether to include hash transformations in the
            environment. Default: ``False``.
        include_hash_inv (bool): Whether the hash transform also maps hashes
            back to strings. Requires ``include_hash=True``.
            Default: ``False``.
        mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask`
            transform will be appended to the env to make sure that the
            actions are properly masked. Default: ``True``.
        pixels (bool): Whether to include pixel-based observations of the
            board. Default: ``False``.

    .. note:: The action spec is a :class:`~torchrl.data.Categorical` with a
        number of actions equal to the number of possible SAN moves. The
        action space is structured as a categorical distribution over all
        possible SAN moves, with the legal moves being a subset of this
        space. The environment uses a mask to ensure only legal moves are
        selected.

    Examples:
        >>> env = ChessEnv(
        ...     include_fen=True,
        ...     include_san=True,
        ...     include_pgn=True,
        ...     include_legal_moves=True,
        ... )
        >>> r = env.reset()
        >>> r = env.rand_step(r)
        >>> rollout = env.rollout(10)
    """

    # NOTE(review): not referenced anywhere in this module — presumably
    # consumed by the Hash transform appended in ``_ChessMeta``; confirm
    # against the Hash transform implementation.
    _hash_table: dict[int, str] = {}

    # PGN of a fresh game: the seven standard roster tags followed by an
    # empty movetext section ("*" denotes an unfinished/unknown result).
    _PGN_RESTART = """[Event "?"]
[Site "?"]
[Date "????.??.??"]
[Round "?"]
[White "?"]
[Black "?"]
[Result "*"]

*"""

    @_classproperty
    def lib(cls):
        """Return the lazily-imported ``chess`` module.

        Raises:
            ImportError: if the `chess` package is not installed.
        """
        try:
            import chess
            import chess.pgn
        except ImportError:
            raise ImportError(
                "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
            )
        return chess

    # Lazily-populated vocabulary of all SAN move strings, shared class-wide.
    _san_moves = []

    @_classproperty
    def san_moves(cls):
        """Return the full list of SAN moves, loading it from disk on first access.

        The list is read once from ``san_moves.txt`` next to this module and
        cached in ``cls._san_moves`` for all subsequent accesses.
        """
        if not cls._san_moves:
            # NOTE(review): "r+" opens the file for read *and* write although
            # it is only read here — plain "r" would suffice; confirm no
            # writer relies on this handle.
            with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
                cls._san_moves.extend(f.read().split("\n"))
        return cls._san_moves

    def _legal_moves_to_index(
        self,
        tensordict: TensorDictBase | None = None,
        board: chess.Board | None = None,  # noqa: F821
        return_mask: bool = False,
        pad: bool = False,
    ) -> torch.Tensor:
        """Map the legal moves of a position to indices in ``san_moves``.

        Args:
            tensordict: in stateless mode, carries the ``"fen"`` or ``"pgn"``
                entry from which the position is rebuilt on ``self.board``.
            board: board to read legal moves from; defaults to ``self.board``.
            return_mask: when ``True``, also return a boolean mask over the
                full SAN vocabulary (shape ``(len(san_moves),)``).
            pad: when ``True``, right-pad the index tensor to 219 entries
                (218 max legal moves + 1) using ``len(san_moves)`` as the
                padding value.

        Returns:
            the index tensor, or the tuple ``(indices, mask)`` when
            ``return_mask=True``.
        """
        if not self.stateful:
            if tensordict is None:
                # trust the board
                pass
            elif self.include_fen:
                fen = tensordict.get("fen", None)
                fen = fen.data
                self.board.set_fen(fen)
                board = self.board
            elif self.include_pgn:
                pgn = tensordict.get("pgn")
                pgn = pgn.data
                board = self._pgn_to_board(pgn, self.board)

        if board is None:
            board = self.board

        # Each legal move is rendered to SAN and looked up in the vocabulary.
        indices = torch.tensor(
            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
            dtype=torch.int64,
        )
        mask = None
        if return_mask:
            mask = self._move_index_to_mask(indices)
        if pad:
            indices = torch.nn.functional.pad(
                indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
            )
        if return_mask:
            return indices, mask
        return indices

    @classmethod
    def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask over the SAN vocabulary, True at ``indices``."""
        return torch.zeros(len(cls.san_moves), dtype=torch.bool).index_fill_(
            0, indices, True
        )

    def __init__(
        self,
        *,
        stateful: bool = True,
        include_san: bool = False,
        include_fen: bool = False,
        include_pgn: bool = False,
        include_legal_moves: bool = False,
        include_hash: bool = False,
        include_hash_inv: bool = False,
        mask_actions: bool = True,
        pixels: bool = False,
    ):
        chess = self.lib
        super().__init__()
        self.full_observation_spec = Composite(
            turn=Categorical(n=2, dtype=torch.bool, shape=()),
        )
        self.include_san = include_san
        self.include_fen = include_fen
        self.include_pgn = include_pgn
        self.mask_actions = mask_actions
        self.include_legal_moves = include_legal_moves
        if include_legal_moves:
            # 218 max possible legal moves per chess board position
            # https://www.stmintz.com/ccc/index.php?id=424966
            # len(self.san_moves)+1 is the padding value
            self.full_observation_spec["legal_moves"] = Bounded(
                0, 1 + len(self.san_moves), shape=(218,), dtype=torch.int64
            )
        if include_san:
            self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6")
        if include_pgn:
            self.full_observation_spec["pgn"] = NonTensor(
                shape=(), example_data=self._PGN_RESTART
            )
        if include_fen:
            self.full_observation_spec["fen"] = NonTensor(shape=(), example_data="any")
        if not stateful and not (include_pgn or include_fen):
            raise RuntimeError(
                "At least one state representation (pgn or fen) must be enabled when stateful "
                f"is {stateful}."
            )

        self.stateful = stateful

        # state_spec is loosely defined as such - it's not really an issue that extra keys
        # can go missing but it allows us to reset the env using fen passed to the reset
        # method.
        self.full_state_spec = self.full_observation_spec.clone()

        self.pixels = pixels
        if pixels:
            # Fail fast at construction time rather than at first render.
            if importlib.util.find_spec("cairosvg") is None:
                raise ImportError(
                    "Please install cairosvg to use this environment with pixel rendering."
                )
            if importlib.util.find_spec("torchvision") is None:
                raise ImportError(
                    "Please install torchvision to use this environment with pixel rendering."
                )
            self.full_observation_spec["pixels"] = Unbounded(
                shape=(3, 390, 390), dtype=torch.uint8
            )

        self.full_action_spec = Composite(
            action=Categorical(n=len(self.san_moves), shape=(), dtype=torch.int64)
        )
        self.full_reward_spec = Composite(
            reward=Unbounded(shape=(1,), dtype=torch.float32)
        )
        if self.mask_actions:
            self.full_observation_spec["action_mask"] = Binary(
                n=len(self.san_moves), dtype=torch.bool
            )
        # done spec generated automatically
        self.board = chess.Board()
        if self.stateful:
            self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))

    def _is_done(self, board):
        # ``|`` on Python bools behaves like a non-short-circuiting ``or``.
        return board.is_game_over() | board.is_fifty_moves()

    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
        """Return all legal actions for the current position.

        Raises:
            RuntimeError: if ``mask_actions=False`` (without the mask, "all
                actions" would include illegal moves; use
                ``env.full_action_spec.enumerate()`` for that).
        """
        if not self.mask_actions:
            raise RuntimeError(
                "Cannot generate legal actions since 'mask_actions=False' was "
                "set. If you really want to generate all actions, not just "
                "legal ones, call 'env.full_action_spec.enumerate()'."
            )
        return super().all_actions(tensordict)

    def _reset(self, tensordict=None):
        """Reset the board, optionally to a position given as fen/pgn in ``tensordict``."""
        fen = None
        pgn = None
        if tensordict is not None:
            dest = tensordict.empty()
            if self.include_fen:
                fen = tensordict.get("fen", None)
                if fen is not None:
                    fen = fen.data
            elif self.include_pgn:
                pgn = tensordict.get("pgn", None)
                if pgn is not None:
                    pgn = pgn.data
        else:
            dest = TensorDict()

        if fen is None and pgn is None:
            # No target position provided: restart from the initial position.
            self.board.reset()
        elif fen is not None:
            self.board.set_fen(fen)
            if self._is_done(self.board):
                raise ValueError(
                    "Cannot reset to a fen that is a gameover state." f" fen: {fen}"
                )
        elif pgn is not None:
            self.board = self._pgn_to_board(pgn)

        # Back-fill whichever representation the caller did not provide.
        if self.include_fen and fen is None:
            fen = self.board.fen()
        if self.include_pgn and pgn is None:
            pgn = self._board_to_pgn(self.board)

        turn = self.board.turn
        if self.include_san:
            if self.board.move_stack:
                move = self.board.peek()
            else:
                move = None
            if move is None:
                # "<start>" marks a position with no preceding move.
                dest.set("san", "<start>")
            else:
                dest.set("san", self.board.san(move))
        if self.include_fen:
            dest.set("fen", fen)
        if self.include_pgn:
            dest.set("pgn", pgn)
        dest.set("turn", turn)
        if self.include_legal_moves:
            moves_idx = self._legal_moves_to_index(
                board=self.board, pad=True, return_mask=self.mask_actions
            )
            if self.mask_actions:
                moves_idx, mask = moves_idx
                dest.set("action_mask", mask)
            dest.set("legal_moves", moves_idx)
        elif self.mask_actions:
            # Mask requested without legal-move indices: compute mask only.
            dest.set(
                "action_mask",
                self._legal_moves_to_index(
                    board=self.board, pad=True, return_mask=True
                )[1],
            )
        if self.pixels:
            dest.set("pixels", self._get_tensor_image(board=self.board))
        return dest

    # Cached module handle for lazy cairosvg import.
    _cairosvg_lib = None

    @_classproperty
    def _cairosvg(cls):
        """Lazily import and cache the ``cairosvg`` module."""
        csvg = cls._cairosvg_lib
        if csvg is None:
            import cairosvg

            csvg = cls._cairosvg_lib = cairosvg
        return csvg

    # Cached module handle for lazy torchvision import.
    _torchvision_lib = None

    @_classproperty
    def _torchvision(cls):
        """Lazily import and cache the ``torchvision`` module."""
        tv = cls._torchvision_lib
        if tv is None:
            import torchvision

            tv = cls._torchvision_lib = torchvision
        return tv

    @classmethod
    def _get_tensor_image(cls, board):
        """Render ``board`` to a ``(3, H, W)`` uint8 tensor (SVG -> PNG -> tensor).

        Raises:
            ImportError: if cairosvg, PIL or torchvision is missing.
        """
        try:
            from PIL import Image

            svg = board._repr_svg_()
            # Convert SVG to PNG using cairosvg
            png_data = io.BytesIO()
            cls._cairosvg.svg2png(bytestring=svg.encode("utf-8"), write_to=png_data)
            png_data.seek(0)
            # Open the PNG image using Pillow
            img = Image.open(png_data)
            img = cls._torchvision.transforms.functional.pil_to_tensor(img)
        except ImportError:
            raise ImportError(
                "Chess rendering requires cairosvg, PIL and torchvision to be installed."
            )
        return img

    @classmethod
    def _pgn_to_board(
        cls, pgn_string: str, board: chess.Board | None = None  # noqa: F821
    ) -> chess.Board:  # noqa: F821
        """Replay a PGN string onto ``board`` (reset first) or a fresh board."""
        pgn_io = io.StringIO(pgn_string)
        game = cls.lib.pgn.read_game(pgn_io)
        if board is None:
            board = cls.lib.Board()
        else:
            board.reset()
        for move in game.mainline_moves():
            board.push(move)
        return board

    @classmethod
    def _add_move_to_pgn(cls, pgn_string: str, move: chess.Move) -> str:  # noqa: F821
        """Append ``move`` to the end of the main line of a PGN string.

        Raises:
            ValueError: if ``pgn_string`` cannot be parsed.
        """
        pgn_io = io.StringIO(pgn_string)
        game = cls.lib.pgn.read_game(pgn_io)
        if game is None:
            raise ValueError("Invalid PGN string")
        game.end().add_variation(move)
        return str(game)

    @classmethod
    def _board_to_pgn(cls, board: chess.Board) -> str:  # noqa: F821
        """Serialize the move history of ``board`` to a PGN string."""
        game = cls.lib.pgn.Game.from_board(board)
        pgn_string = str(game)
        return pgn_string

    def _step(self, tensordict):
        """Apply the SAN move indexed by ``tensordict["action"]`` and observe."""
        # action
        action = tensordict.get("action")
        board = self.board

        pgn = None
        fen = None
        if not self.stateful:
            # Stateless mode: rebuild the position from the observation.
            if self.include_fen:
                fen = tensordict.get("fen").data
                board.set_fen(fen)
            elif self.include_pgn:
                pgn = tensordict.get("pgn").data
                board = self._pgn_to_board(pgn, board)
            else:
                raise RuntimeError(
                    "Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
                )

        san = self.san_moves[action]
        board.push_san(san)

        dest = tensordict.empty()

        # Collect data
        if self.include_fen:
            fen = board.fen()
            dest.set("fen", fen)

        if self.include_pgn:
            if pgn is not None:
                # Cheaper than re-serializing the whole board history.
                pgn = self._add_move_to_pgn(pgn, board.move_stack[-1])
            else:
                pgn = self._board_to_pgn(board)
            dest.set("pgn", pgn)

        if self.include_san:
            dest.set("san", san)

        if self.include_legal_moves:
            moves_idx = self._legal_moves_to_index(
                board=board, pad=True, return_mask=self.mask_actions
            )
            if self.mask_actions:
                moves_idx, mask = moves_idx
                dest.set("action_mask", mask)
            dest.set("legal_moves", moves_idx)
        elif self.mask_actions:
            dest.set(
                "action_mask",
                self._legal_moves_to_index(
                    board=self.board, pad=True, return_mask=True
                )[1],
            )

        turn = torch.tensor(board.turn)
        done = self._is_done(board)
        if board.is_checkmate():
            # turn flips after every move, even if the game is over
            # winner = not turn
            reward_val = 1  # if winner == self.lib.WHITE else 0
        elif done:
            # Draw (stalemate, fifty-move rule, insufficient material, ...).
            reward_val = 0.5
        else:
            reward_val = 0.0

        reward = torch.tensor([reward_val], dtype=torch.float32)
        dest.set("reward", reward)
        dest.set("turn", turn)
        dest.set("done", [done])
        dest.set("terminated", [done])
        if self.pixels:
            dest.set("pixels", self._get_tensor_image(board=self.board))
        return dest

    def _set_seed(self, *args, **kwargs):
        # No-op: this environment uses no RNG of its own (moves are chosen by
        # the caller).
        ...

    def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
        """Return the number of actions, after refreshing the action space."""
        self._set_action_space(tensordict)
        return self.action_spec.cardinality()

# Docs: comprehensive developer documentation for PyTorch — https://pytorch.org/docs
# Tutorials: in-depth tutorials for beginners and advanced developers — https://pytorch.org/tutorials
# Resources: development resources and answers to questions — https://pytorch.org/resources