Shortcuts

Source code for torchrl.envs.custom.tictactoeenv

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
    CompositeSpec,
    DiscreteTensorSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)
from torchrl.envs.common import EnvBase


class TicTacToeEnv(EnvBase):
    """A Tic-Tac-Toe implementation.

    Keyword Args:
        single_player (bool, optional): whether one or two players have to be
            accounted for. ``single_player=True`` means that ``"player1"`` is
            playing randomly. If ``False`` (default), at each turn, one of the
            two players has to play.
        device (torch.device, optional): the device where to put the tensors.
            Defaults to ``None`` (default device).

    The environment is stateless. To run it across multiple batches, call

        >>> env.reset(TensorDict(batch_size=desired_batch_size))

    If the ``"mask"`` entry is present, ``rand_action`` takes it into account to
    generate the next action. Any policy executed on this env should take this
    mask into account, as well as the turn of the player (stored in the
    ``"turn"`` output entry).

    Board encoding:
        Each cell of ``"board"`` holds ``-1`` when it is empty. When a player
        moves, their mark is written as ``1``; the board returned by the step
        then swaps ``0`` and ``1`` so that the player who moves next always
        sees their own marks as ``1`` and the opponent's marks as ``0``.
        ``"mask"`` is ``True`` for the cells that are still empty.

    Specs:
        CompositeSpec(
            output_spec: CompositeSpec(
                full_observation_spec: CompositeSpec(
                    board: DiscreteTensorSpec(
                        shape=torch.Size([3, 3]),
                        space=DiscreteBox(n=2),
                        dtype=torch.int32,
                        domain=discrete),
                    turn: DiscreteTensorSpec(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=2),
                        dtype=torch.int32,
                        domain=discrete),
                    mask: DiscreteTensorSpec(
                        shape=torch.Size([9]),
                        space=DiscreteBox(n=2),
                        dtype=torch.bool,
                        domain=discrete),
                    shape=torch.Size([])),
                full_reward_spec: CompositeSpec(
                    player0: CompositeSpec(
                        reward: UnboundedContinuousTensorSpec(
                            shape=torch.Size([1]),
                            space=ContinuousBox(
                                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
                            dtype=torch.float32,
                            domain=continuous),
                        shape=torch.Size([])),
                    player1: CompositeSpec(
                        reward: UnboundedContinuousTensorSpec(
                            shape=torch.Size([1]),
                            space=ContinuousBox(
                                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
                            dtype=torch.float32,
                            domain=continuous),
                        shape=torch.Size([])),
                    shape=torch.Size([])),
                full_done_spec: CompositeSpec(
                    done: DiscreteTensorSpec(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=2),
                        dtype=torch.bool,
                        domain=discrete),
                    terminated: DiscreteTensorSpec(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=2),
                        dtype=torch.bool,
                        domain=discrete),
                    truncated: DiscreteTensorSpec(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=2),
                        dtype=torch.bool,
                        domain=discrete),
                    shape=torch.Size([])),
                shape=torch.Size([])),
            input_spec: CompositeSpec(
                full_state_spec: CompositeSpec(
                    board: DiscreteTensorSpec(
                        shape=torch.Size([3, 3]),
                        space=DiscreteBox(n=2),
                        dtype=torch.int32,
                        domain=discrete),
                    turn: DiscreteTensorSpec(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=2),
                        dtype=torch.int32,
                        domain=discrete),
                    mask: DiscreteTensorSpec(
                        shape=torch.Size([9]),
                        space=DiscreteBox(n=2),
                        dtype=torch.bool,
                        domain=discrete),
                    shape=torch.Size([])),
                full_action_spec: CompositeSpec(
                    action: DiscreteTensorSpec(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=9),
                        dtype=torch.int64,
                        domain=discrete),
                    shape=torch.Size([])),
                shape=torch.Size([])),
            shape=torch.Size([]))

    To run a dummy rollout, execute the following command (the leading
    dimension of the result is the number of moves actually played, which is
    at most 9 since the board has 9 cells):

    Examples:
        >>> env = TicTacToeEnv()
        >>> env.rollout(10)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False),
                        done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                        player0: TensorDict(
                            fields={
                                reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([9]),
                            device=None,
                            is_shared=False),
                        player1: TensorDict(
                            fields={
                                reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([9]),
                            device=None,
                            is_shared=False),
                        terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)},
                    batch_size=torch.Size([9]),
                    device=None,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([9]),
            device=None,
            is_shared=False)

    """

    # batch_locked is set to False since various batch sizes can be provided to the env
    batch_locked: bool = False

    def __init__(self, *, single_player: bool = False, device=None):
        """Build the action, observation, state, reward and done specs.

        Args:
            single_player (bool, optional): if ``True``, every move of player 0
                is followed by a random move for the other player inside the
                same step. Defaults to ``False``.
            device (torch.device, optional): device passed to the parent class
                and to every spec. Defaults to ``None``.
        """
        super().__init__(device=device)
        self.single_player = single_player
        # One categorical action: the index (0..8) of the flattened board cell.
        self.action_spec: DiscreteTensorSpec = DiscreteTensorSpec(
            n=9,
            shape=(),
            device=device,
        )
        self.full_observation_spec: CompositeSpec = CompositeSpec(
            board=UnboundedContinuousTensorSpec(
                shape=(3, 3), dtype=torch.int, device=device
            ),
            turn=DiscreteTensorSpec(
                2,
                shape=(1,),
                dtype=torch.int,
                device=device,
            ),
            mask=DiscreteTensorSpec(
                2,
                shape=(9,),
                dtype=torch.bool,
                device=device,
            ),
            device=device,
        )
        # The env is stateless: everything it observes must be fed back as state.
        self.state_spec: CompositeSpec = self.observation_spec.clone()
        self.reward_spec: CompositeSpec = CompositeSpec(
            {
                ("player0", "reward"): UnboundedContinuousTensorSpec(
                    shape=(1,), device=device
                ),
                ("player1", "reward"): UnboundedContinuousTensorSpec(
                    shape=(1,), device=device
                ),
            },
            device=device,
        )
        self.full_done_spec: CompositeSpec = CompositeSpec(
            done=DiscreteTensorSpec(2, shape=(1,), dtype=torch.bool, device=device),
            device=device,
        )
        self.full_done_spec["terminated"] = self.full_done_spec["done"].clone()
        self.full_done_spec["truncated"] = self.full_done_spec["done"].clone()

    def _reset(self, reset_td: TensorDict | None) -> TensorDict:
        """Return the initial state of a game.

        Args:
            reset_td (TensorDict or None): only its batch shape is read; ``None``
                produces an unbatched state.

        Returns:
            A tensordict holding an empty board (every cell ``-1``), a zero
            ``"turn"``, a ``"mask"`` filled with ``True`` and zeroed done
            entries.
        """
        shape = reset_td.shape if reset_td is not None else ()
        state = self.state_spec.zero(shape)
        # zero() yields a board of zeros; shift it so that every cell is empty (-1).
        state["board"] -= 1
        state["mask"].fill_(True)
        return state.update(self.full_done_spec.zero(shape))

    def _step(self, state: TensorDict) -> TensorDict:
        """Play ``state["action"]`` for the player whose turn it is.

        Args:
            state (TensorDict): must contain ``"board"``, ``"turn"`` and
                ``"action"`` (plus ``"mask"`` in single-player mode, where it is
                read by :meth:`rand_action`).

        Returns:
            A new tensordict with the done flags, the per-player rewards, the
            board seen from the next player's perspective, the next ``"turn"``
            and the updated ``"mask"``.
        """
        # A contiguous copy guarantees that flatten() below is a view, so the
        # in-place scatter really writes into ``board``.
        board = state["board"].clone(memory_format=torch.contiguous_format)
        turn = state["turn"].clone()
        action = state["action"]
        board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1)
        # The win must be evaluated on the board that contains the new mark.
        wins = self.win(board, action)

        mask = board.flatten(-2, -1) == -1
        # The game ends on a win or when no empty cell is left.
        done = wins | ~mask.any(-1, keepdim=True)
        terminated = done.clone()

        reward_0 = wins & (turn == 0)
        reward_1 = wins & (turn == 1)

        state = TensorDict(
            {
                "done": done,
                "terminated": terminated,
                ("player0", "reward"): reward_0.float(),
                ("player1", "reward"): reward_1.float(),
                # Swap 0 and 1 (empty cells untouched) so that the next player
                # sees their own marks as 1.
                "board": torch.where(board == -1, board, 1 - board),
                "turn": 1 - turn,
                "mask": mask,
            },
            batch_size=state.batch_size,
        )
        if not self.single_player:
            return state

        # Games that are still running right after player 0 moved need a random
        # reply from the other player.
        select = (~done & (turn == 0)).squeeze(-1)
        if not select.any():
            return state
        if select.all():
            return self._step(self.rand_action(state))
        # Only part of the batch needs a reply: step that sub-batch and write
        # the result back at the selected positions.
        state[select] = self._step(self.rand_action(state[select]))
        return state

    def _set_seed(self, seed: int | None):
        """Do nothing: this class keeps no random generator of its own."""

    @staticmethod
    def win(board: torch.Tensor, action: torch.Tensor | None = None) -> torch.Tensor:
        """Tell whether the player who just moved has completed a line.

        Args:
            board (torch.Tensor): board(s) of shape ``(*batch, 3, 3)`` *after*
                the move, with the mover's marks encoded as ``1``.
            action (torch.Tensor, optional): kept for backward compatibility;
                every row, column and diagonal is checked, so it is not read.

        Returns:
            A boolean tensor of shape ``(*batch, 1)``.
        """
        # Cells hold -1, 0 or 1, hence a line sums to 3 only if it is all ones.
        rows = (board.sum(-1) == 3).any(-1)
        cols = (board.sum(-2) == 3).any(-1)
        diag = board.diagonal(0, -2, -1).sum(-1) == 3
        anti_diag = board.flip(-1).diagonal(0, -2, -1).sum(-1) == 3
        return (rows | cols | diag | anti_diag).unsqueeze(-1)

    @staticmethod
    def full(board: torch.Tensor) -> bool:
        """Return ``True`` when no cell of ``board`` is empty (``-1``).

        The check reduces over every dimension, so for a batch of boards it is
        ``True`` only if all of them are full.
        """
        return bool((board != -1).all())

    @staticmethod
    def get_action_mask():
        """Placeholder; not implemented."""
        pass

    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
        """Sample a random action, restricted to the empty cells when known.

        Args:
            tensordict (TensorDictBase, optional): current state. Its ``"mask"``
                entry, if present, restricts the sampled actions. When
                ``None``, sampling is delegated to the parent class.

        Returns:
            ``tensordict`` with the action written under ``self.action_key``
            (or the parent's result when ``tensordict`` is ``None``).
        """
        if tensordict is None:
            # There is no state to read a mask from.
            return super().rand_action(tensordict)
        mask = tensordict.get("mask", None)
        action_spec = self.action_spec
        # expand()/clone() give a private spec, so the mask never leaks into
        # the env's own action spec.
        if tensordict.ndim:
            action_spec = action_spec.expand(tensordict.shape)
        else:
            action_spec = action_spec.clone()
        action_spec.update_mask(mask)
        tensordict.set(self.action_key, action_spec.rand())
        return tensordict

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources