Shortcuts

Source code for torchrl.data.datasets.atari_dqn

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import gzip
import io
import json
import os
import shutil
import subprocess
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Callable

import numpy as np
import torch
from tensordict import MemoryMappedTensor, TensorDict, TensorDictBase
from torch import multiprocessing as mp
from torchrl._utils import logger as torchrl_logger
from torchrl.data.datasets.common import BaseDatasetExperienceReplay

from torchrl.data.replay_buffers.samplers import (
    SamplerWithoutReplacement,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.utils import _classproperty


class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
    """Atari DQN Experience replay class.

    The Atari DQN dataset (https://offline-rl.github.io/) is a collection of 5 training
    iterations of DQN over each of the Atari 2600 games for a total of 200 million frames.
    The sub-sampling rate (frame-skip) is equal to 4, meaning that each game dataset
    has 50 million steps in total.

    The data format follows the :ref:`TED convention <TED-format>`. Since the dataset is quite heavy,
    the data formatting is done on-line, at sampling time.

    To make training more modular, we split the dataset in each of the Atari games
    and separate each training round. Consequently, each dataset is presented as
    a Storage of length 50x10^6 elements. Under the hood, this dataset is split
    in 50 memory-mapped tensordicts of length 1 million each.

    Args:
        dataset_id (str): The dataset to be downloaded.
            Must be part of ``AtariDQNExperienceReplay.available_datasets``.
        batch_size (int): Batch-size used during sampling. Can be overridden by
            `data.sample(batch_size)` if necessary.

    Keyword Args:
        root (Path or str, optional): The AtariDQN dataset root directory.
            The actual dataset memory-mapped files will be saved under
            `<root>/<dataset_id>`. If none is provided, it defaults to
            ``~/.cache/torchrl/atari``.
        num_procs (int, optional): number of processes to launch for preprocessing.
            Has no effect whenever the data is already downloaded. Defaults to 0
            (no multiprocessing used).
        download (bool or str, optional): Whether the dataset should be downloaded if
            not found. Defaults to ``True``. Download can also be passed as ``"force"``,
            in which case the downloaded data will be overwritten.
        sampler (Sampler, optional): the sampler to be used. If none is provided
            a default RandomSampler() will be used.
        writer (Writer, optional): the writer to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs.  Used when using batched
            loading from a map-style dataset.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading.
        transform (Transform, optional): Transform to be executed when sample() is called.
            To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
        num_slices (int, optional): the number of slices to be sampled. The batch-size
            must be greater or equal to the ``num_slices`` argument. Exclusive
            with ``slice_len``. Defaults to ``None`` (no slice sampling).
            The ``sampler`` arg will override this value.
        slice_len (int, optional): the length of the slices to be sampled. The batch-size
            must be greater or equal to the ``slice_len`` argument and divisible
            by it. Exclusive with ``num_slices``. Defaults to ``None`` (no slice sampling).
            The ``sampler`` arg will override this value.
        strict_len (bool, optional): forwarded as ``strict_length`` to the slice
            sampler. If ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch.
            Be mindful that this can result in effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
            The ``sampler`` arg will override this value.
        replacement (bool, optional): if ``False``, sampling will occur without replacement.
            The ``sampler`` arg will override this value.
        mp_start_method (str, optional): the start method for multiprocessed
            download. Defaults to ``"fork"``.

    Attributes:
        available_datasets: list of available datasets, formatted as `<game_name>/<run>`.
            Example: `"Pong/5"`, `"Krull/2"`, ...
        dataset_id (str): the name of the dataset.
        episodes (torch.Tensor): a 1d tensor indicating to what run each of the
            1M frames belongs. To be used with :class:`~torchrl.data.replay_buffers.SliceSampler`
            to cheaply sample slices of episodes.

    Examples:
        >>> from torchrl.data.datasets import AtariDQNExperienceReplay
        >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128)
        >>> for data in dataset:
        ...     print(data)
        ...     break
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
                index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
                metadata: NonTensorData(
                    data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
                    batch_size=torch.Size([128]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
                        observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
                        truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
                    batch_size=torch.Size([128]),
                    device=None,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
                truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
            batch_size=torch.Size([128]),
            device=None,
            is_shared=False)

    .. warning::
      Atari-DQN does not provide the next observation after a termination signal.
      In other words, there is no way to obtain the ``("next", "observation")`` state
      when ``("next", "done")`` is ``True``. This value is filled with 0s but should
      not be used in practice. If TorchRL's value estimators (:class:`~torchrl.objectives.values.ValueEstimator`)
      are used, this should not be an issue.

    .. note::
      Because the construction of the sampler for episode sampling is slightly
      convoluted, we made it easy for users to pass the arguments of the
      :class:`~torchrl.data.replay_buffers.SliceSampler` directly to the
      ``AtariDQNExperienceReplay`` dataset: any of the ``num_slices`` or
      ``slice_len`` arguments will make the sampler an instance of
      :class:`~torchrl.data.replay_buffers.SliceSampler`. The ``strict_len``
      can also be passed.

        >>> from torchrl.data.datasets import AtariDQNExperienceReplay
        >>> from torchrl.data.replay_buffers import SliceSampler
        >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64)
        >>> for data in dataset:
        ...     print(data)
        ...     print(data.get("index"))  # indices are in 2 groups of consecutive values
        ...     break
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
                index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
                metadata: NonTensorData(
                    data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
                    batch_size=torch.Size([128]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([128]),
                    device=None,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
                truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
            batch_size=torch.Size([128]),
            device=None,
            is_shared=False)
        tensor([2657628, 2657629, 2657630, 2657631, 2657632, 2657633, 2657634, 2657635,
                2657636, 2657637, 2657638, 2657639, 2657640, 2657641, 2657642, 2657643,
                2657644, 2657645, 2657646, 2657647, 2657648, 2657649, 2657650, 2657651,
                2657652, 2657653, 2657654, 2657655, 2657656, 2657657, 2657658, 2657659,
                2657660, 2657661, 2657662, 2657663, 2657664, 2657665, 2657666, 2657667,
                2657668, 2657669, 2657670, 2657671, 2657672, 2657673, 2657674, 2657675,
                2657676, 2657677, 2657678, 2657679, 2657680, 2657681, 2657682, 2657683,
                2657684, 2657685, 2657686, 2657687, 2657688, 2657689, 2657690, 2657691,
                1995687, 1995688, 1995689, 1995690, 1995691, 1995692, 1995693, 1995694,
                1995695, 1995696, 1995697, 1995698, 1995699, 1995700, 1995701, 1995702,
                1995703, 1995704, 1995705, 1995706, 1995707, 1995708, 1995709, 1995710,
                1995711, 1995712, 1995713, 1995714, 1995715, 1995716, 1995717, 1995718,
                1995719, 1995720, 1995721, 1995722, 1995723, 1995724, 1995725, 1995726,
                1995727, 1995728, 1995729, 1995730, 1995731, 1995732, 1995733, 1995734,
                1995735, 1995736, 1995737, 1995738, 1995739, 1995740, 1995741, 1995742,
                1995743, 1995744, 1995745, 1995746, 1995747, 1995748, 1995749, 1995750])

    .. note::
      As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble`:

        >>> from torchrl.data.datasets import AtariDQNExperienceReplay
        >>> from torchrl.data.replay_buffers import ReplayBufferEnsemble
        >>> # we change this parameter for quick experimentation, in practice it should be left untouched
        >>> AtariDQNExperienceReplay._max_runs = 2
        >>> dataset_asterix = AtariDQNExperienceReplay("Asterix/5", batch_size=128, slice_len=64, num_procs=4)
        >>> dataset_pong = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64, num_procs=4)
        >>> dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True)
        >>> sample = dataset.sample()
        >>> print("first sample, Pong", sample[0])
        first sample, Pong TensorDict(
            fields={
                action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
                index: TensorDict(
                    fields={
                        buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
                        index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([64]),
                    device=None,
                    is_shared=False),
                metadata: NonTensorData(
                    data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
                    batch_size=torch.Size([64]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([64]),
                    device=None,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
                truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
            batch_size=torch.Size([64]),
            device=None,
            is_shared=False)
        >>> print("second sample, Asterix", sample[1])
        second sample, Asterix TensorDict(
            fields={
                action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
                index: TensorDict(
                    fields={
                        buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
                        index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([64]),
                    device=None,
                    is_shared=False),
                metadata: NonTensorData(
                    data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Asterix/5'},
                    batch_size=torch.Size([64]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([64]),
                    device=None,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
                truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
            batch_size=torch.Size([64]),
            device=None,
            is_shared=False)
        >>> print("Aggregate (metadata hidden)", sample)
        Aggregate (metadata hidden) LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
                index: LazyStackedTensorDict(
                    fields={
                        buffer_ids: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
                        index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, 64]),
                    device=None,
                    is_shared=False,
                    stack_dim=0),
                metadata: LazyStackedTensorDict(
                    fields={
                    },
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, 64]),
                    device=None,
                    is_shared=False,
                    stack_dim=0),
                next: LazyStackedTensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, 64]),
                    device=None,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, 64]),
            device=None,
            is_shared=False,
            stack_dim=0)

    """

    @_classproperty
    def available_datasets(cls):
        """List every ``<game_name>/<run>`` identifier (50 games, runs 1 to 5)."""
        games = [
            "AirRaid",
            "Alien",
            "Amidar",
            "Assault",
            "Asterix",
            "Asteroids",
            "Atlantis",
            "BankHeist",
            "BattleZone",
            "BeamRider",
            "Berzerk",
            "Bowling",
            "Boxing",
            "Breakout",
            "Carnival",
            "Centipede",
            "ChopperCommand",
            "CrazyClimber",
            "DemonAttack",
            "DoubleDunk",
            "ElevatorAction",
            "Enduro",
            "FishingDerby",
            "Freeway",
            "Frostbite",
            "Gopher",
            "Gravitar",
            "Hero",
            "IceHockey",
            "Jamesbond",
            "JourneyEscape",
            "Kangaroo",
            "Krull",
            "KungFuMaster",
            "MontezumaRevenge",
            "MsPacman",
            "NameThisGame",
            "Phoenix",
            "Pitfall",
            "Pong",
            "Pooyan",
            "PrivateEye",
            "Qbert",
            "Riverraid",
            "RoadRunner",
            "Robotank",
            "Seaquest",
            "Skiing",
            "Solaris",
            "SpaceInvaders",
        ]
        return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)]

    # If we want to keep track of the original atari files
    tmpdir = None
    # use _max_runs for debugging, avoids downloading the entire dataset
    _max_runs = None

    def __init__(
        self,
        dataset_id: str,
        batch_size: int | None = None,
        *,
        root: str | Path | None = None,
        download: bool | str = True,
        sampler=None,
        writer=None,
        transform: "Transform" | None = None,  # noqa: F821
        num_procs: int = 0,
        num_slices: int | None = None,
        slice_len: int | None = None,
        strict_len: bool = True,
        replacement: bool = True,
        mp_start_method: str = "fork",
        **kwargs,
    ):
        if dataset_id not in self.available_datasets:
            raise ValueError(
                "The dataseet_id is not part of the available datasets. The dataset should be named <game_name>/<run> "
                "where <game_name> is one of the Atari 2600 games and the run is a number betweeen 1 and 5. "
                "The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets."
            )
        self.dataset_id = dataset_id
        # Imported locally, as in the original implementation.
        from torchrl.data.datasets.utils import _get_root_dir

        if root is None:
            root = _get_root_dir("atari")
        self.root = root
        self.num_procs = num_procs
        self.mp_start_method = mp_start_method
        if download == "force" or (download and not self._is_downloaded):
            try:
                self._download_and_preproc()
            except Exception:
                # remove temporary data so that a later call starts from scratch
                if os.path.exists(self.dataset_path):
                    shutil.rmtree(self.dataset_path)
                raise
        if self._downloaded_and_preproc:
            storage = TensorStorage(TensorDict.load_memmap(self.dataset_path))
        else:
            storage = _AtariStorage(self.dataset_path)
        if writer is None:
            writer = ImmutableDatasetWriter()
        if sampler is None:
            if num_slices is not None or slice_len is not None:
                # ``strict_len`` used to be accepted but silently dropped; it is
                # now forwarded to the slice samplers as ``strict_length``.
                if not replacement:
                    sampler = SliceSamplerWithoutReplacement(
                        num_slices=num_slices,
                        slice_len=slice_len,
                        trajectories=storage.episodes,
                        strict_length=strict_len,
                    )
                else:
                    sampler = SliceSampler(
                        num_slices=num_slices,
                        slice_len=slice_len,
                        trajectories=storage.episodes,
                        cache_values=True,
                        strict_length=strict_len,
                    )
            elif not replacement:
                sampler = SamplerWithoutReplacement()

        super().__init__(
            storage=storage,
            batch_size=batch_size,
            writer=writer,
            sampler=sampler,
            collate_fn=lambda x: x,
            transform=transform,
            **kwargs,
        )

    @property
    def episodes(self):
        """Episode index of each frame, as exposed by the underlying storage."""
        return self._storage.episodes

    @property
    def root(self) -> Path:
        """Root directory under which the datasets are stored."""
        return self._root

    @root.setter
    def root(self, value):
        self._root = Path(value)

    @property
    def dataset_path(self) -> Path:
        """Directory of this dataset: ``<root>/<dataset_id>``."""
        return self._root / self.dataset_id

    @property
    def _downloaded_and_preproc(self):
        # A "meta.json" directly under the dataset path marks a dataset that can
        # be loaded as a single memory-mapped tensordict.
        return os.path.exists(self.dataset_path / "meta.json")

    @property
    def _is_downloaded(self):
        if os.path.exists(self.dataset_path / "meta.json"):
            return True
        if os.path.exists(self.dataset_path / "processed.json"):
            with open(self.dataset_path / "processed.json", "r") as jsonfile:
                # The stored value must match the current ``_max_runs``, otherwise
                # the dataset is considered missing and gets downloaded again.
                return json.load(jsonfile).get("processed", False) == self._max_runs
        return False

    def _download_and_preproc(self):
        """Download every run of the dataset with ``gsutil`` and preprocess it.

        Any existing data under :attr:`dataset_path` is removed first. Runs are
        processed sequentially when ``num_procs == 0`` and through a
        multiprocessing pool otherwise. A ``processed.json`` file recording
        ``_max_runs`` is written at the end.

        Raises:
            RuntimeError: if ``gsutil`` returned no ``.gz`` file.
        """
        torchrl_logger.info(
            f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while..."
        )
        if os.path.exists(self.dataset_path):
            shutil.rmtree(self.dataset_path)
        with tempfile.TemporaryDirectory() as tempdir:
            if self.tmpdir is not None:
                tempdir = self.tmpdir
                # Create the user-provided directory if needed. The previous
                # ``os.listdir`` probe raised when the directory was missing.
                os.makedirs(tempdir, exist_ok=True)
            # get the list of runs
            command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs"
            output = subprocess.run(command, shell=True, capture_output=True)
            # File names contain "$", which must be escaped because the download
            # command is later run through the shell.
            files = [
                file.decode("utf-8").replace("$", "\\$")
                for file in output.stdout.splitlines()
                if file.endswith(b".gz")
            ]
            self.remote_gz_files = self._list_runs(None, files)
            remote_gz_files = list(self.remote_gz_files)
            if not len(remote_gz_files):
                raise RuntimeError(
                    "Could not load the file list. Did you install gsutil?"
                )

            # Runs are sorted, so the last key is the highest run index.
            total_runs = remote_gz_files[-1]
            if self.num_procs == 0:
                for run, run_files in self.remote_gz_files.items():
                    self._download_and_proc_split(
                        run,
                        run_files,
                        tempdir=tempdir,
                        dataset_path=self.dataset_path,
                        total_episodes=total_runs,
                        max_runs=self._max_runs,
                        multithreaded=True,
                    )
            else:
                func = functools.partial(
                    self._download_and_proc_split,
                    tempdir=tempdir,
                    dataset_path=self.dataset_path,
                    total_episodes=total_runs,
                    max_runs=self._max_runs,
                    multithreaded=False,
                )
                args = [
                    (run, run_files)
                    for (run, run_files) in self.remote_gz_files.items()
                ]
                ctx = mp.get_context(self.mp_start_method)
                with ctx.Pool(self.num_procs) as pool:
                    pool.starmap(func, args)
        with open(self.dataset_path / "processed.json", "w") as file:
            # we save self._max_runs such that changing the number of runs to process
            # forces the data to be re-downloaded
            json.dump({"processed": self._max_runs}, file)

    @classmethod
    def _download_and_proc_split(
        cls,
        run,
        run_files,
        *,
        tempdir,
        dataset_path,
        total_episodes,
        max_runs,
        multithreaded=True,
    ):
        """Download the files of one run into ``tempdir/<run>`` and preprocess them.

        Does nothing when ``max_runs`` is set and ``run >= max_runs``. The
        downloaded files are deleted once the run has been preprocessed.
        """
        if (max_runs is not None) and (run >= max_runs):
            return
        tempdir = Path(tempdir)
        # exist_ok: a persistent ``tmpdir`` may still hold this folder from a
        # previous attempt.
        os.makedirs(tempdir / str(run), exist_ok=True)
        files_str = " ".join(run_files)
        torchrl_logger.info(f"downloading {files_str}")
        if multithreaded:
            command = f"gsutil -m cp {files_str} {tempdir}/{run}"
        else:
            command = f"gsutil cp {files_str} {tempdir}/{run}"
        subprocess.run(command, shell=True)
        local_gz_files = cls._list_runs(tempdir / str(run))
        # we iterate over the dict but this one has length 1
        for run in local_gz_files:
            path = dataset_path / str(run)
            try:
                cls._preproc_run(path, local_gz_files, run)
            except Exception:
                # ``path`` may not exist yet: do not let the cleanup hide the
                # original error.
                shutil.rmtree(path, ignore_errors=True)
                raise
            shutil.rmtree(tempdir / str(run))
        torchrl_logger.info(f"Concluded run {run} out of {total_episodes}")

    @classmethod
    def _preproc_run(cls, path, gz_files, run):
        """Convert the gzipped numpy files of ``run`` into memory-mapped tensors under ``path``."""
        files = gz_files[run]
        td = TensorDict()
        path = Path(path)
        for file in files:
            name = str(Path(file).parts[-1]).split(".")[0]
            with gzip.GzipFile(file, mode="rb") as f:
                file_content = f.read()
                file_content = io.BytesIO(file_content)
                file_content = np.load(file_content)
                t = torch.as_tensor(file_content)
            # Create the memmap file
            key = cls._process_name(name)
            if key == ("data", "observation"):
                # One extra slot is allocated at the end so that reading
                # ``index + 1`` is valid for the last frame of the run.
                shape = t.shape
                shape = [shape[0] + 1] + list(shape[1:])
                filename = path / "data" / "observation.memmap"
                os.makedirs(filename.parent, exist_ok=True)
                mmap = MemoryMappedTensor.empty(shape, dtype=t.dtype, filename=filename)
                mmap[:-1].copy_(t)
                td[key] = mmap
            elif key in (
                ("data", "reward"),
                ("data", "done"),
                ("data", "terminated"),
            ):
                # These entries are stored under the "next" sub-tensordict.
                filename = path / "data" / "next" / (key[-1] + ".memmap")
                os.makedirs(filename.parent, exist_ok=True)
                mmap = MemoryMappedTensor.from_tensor(t, filename=filename)
                td["data", "next", key[1:]] = mmap
            else:
                filename = path.joinpath(*key[:-1], key[-1] + ".memmap")
                os.makedirs(filename.parent, exist_ok=True)
                mmap = MemoryMappedTensor.from_tensor(t, filename=filename)
                td[key] = mmap
        # ``path`` ends with <game>/<run>/<split>: keep "<game>/<run>".
        td.set_non_tensor("dataset_id", "/".join(path.parts[-3:-1]))
        td.memmap_(path, copy_existing=False)

    @staticmethod
    def _process_name(name):
        """Map a file stem to a tensordict key.

        A trailing ``"_ckpt"`` is dropped, names containing ``"store"`` are
        placed under ``"data"`` and ``"terminal"`` is renamed ``"terminated"``.
        """
        if name.endswith("_ckpt"):
            name = name[:-5]
        if "store" in name:
            key = ("data", name.split("_")[1])
        else:
            key = (name,)
        if key[-1] == "terminal":
            key = (*key[:-1], "terminated")
        return key

    @classmethod
    def _list_runs(cls, download_path, gz_files=None) -> dict:
        """Group ``.gz`` files by the integer found between the two dots of their name.

        When ``gz_files`` is ``None``, the files are discovered by walking
        ``download_path``. The returned dict is sorted by that integer.
        """
        path = download_path
        if gz_files is None:
            gz_files = []
            for root, _, files in os.walk(path):
                for file in files:
                    if file.endswith(".gz"):
                        gz_files.append(os.path.join(root, file))
        runs = defaultdict(list)
        for file in gz_files:
            filename = Path(file).parts[-1]
            name, episode, extension = str(filename).split(".")
            episode = int(episode)
            runs[episode].append(file)
        return dict(sorted(runs.items(), key=lambda x: x[0]))

    def preprocess(
        self,
        fn: Callable[[TensorDictBase], TensorDictBase],
        dim: int = 0,
        num_workers: int | None = None,
        *,
        chunksize: int | None = None,
        num_chunks: int | None = None,
        pool: mp.Pool | None = None,
        generator: torch.Generator | None = None,
        max_tasks_per_child: int | None = None,
        worker_threads: int = 1,
        index_with_generator: bool = False,
        pbar: bool = False,
        mp_start_method: str | None = None,
        dest: str | Path,
        num_frames: int | None = None,
    ):
        """Apply ``fn`` to the dataset and return the result as a new storage.

        ``fn`` is first called on ``self[0]`` (without its ``"metadata"`` entry)
        to infer the structure of the output. A memory-mapped tensordict of
        ``num_frames`` elements is then allocated under ``dest`` and filled
        chunk by chunk through ``TensorDict.map``.

        Args:
            fn (callable): function applied to the samples of the dataset.
            dim (int): must be ``0``.
            num_workers, chunksize, num_chunks, pool, generator,
                max_tasks_per_child, worker_threads, index_with_generator, pbar,
                mp_start_method: forwarded to ``TensorDict.map``.
            dest (path or str): directory where the memory-mapped result is written.
            num_frames (int, optional): number of frames to process. Defaults
                to ``len(self)``.

        Returns:
            A ``TensorStorage`` wrapping the processed data and the metadata of
            the first item.

        Raises:
            RuntimeError: if ``dim != 0``.
        """
        # Fail fast, before any expensive allocation.
        if dim != 0:
            raise RuntimeError("dim != 0 is not supported.")

        # The result must outlive this call, so it is written to ``dest`` and
        # not to a temporary directory that would be deleted on return.
        dest = Path(dest)
        dest.mkdir(parents=True, exist_ok=True)

        first_item = self[0]
        metadata = first_item.pop("metadata")

        mmap = fn(first_item)
        if num_frames is None:
            num_frames = len(self)
        mmap = mmap.expand(num_frames, *first_item.shape)
        mmap = mmap.memmap_like(dest, num_threads=32)
        with mmap.unlock_():
            # Each worker reads the indices of its chunk from this entry.
            mmap["_indices"] = torch.arange(mmap.shape[0])
        mmap.memmap_(dest, num_threads=32)

        def func(mmap: TensorDictBase):
            idx = mmap["_indices"]
            orig = self[idx].exclude("metadata")
            orig = fn(orig)
            mmap.update(orig, inplace=True)
            return

        mmap.map(
            fn=CloudpickleWrapper(func),
            dim=dim,
            num_workers=num_workers,
            chunksize=chunksize,
            num_chunks=num_chunks,
            pool=pool,
            generator=generator,
            max_tasks_per_child=max_tasks_per_child,
            worker_threads=worker_threads,
            index_with_generator=index_with_generator,
            mp_start_method=mp_start_method,
            pbar=pbar,
        )
        with mmap.unlock_():
            return TensorStorage(mmap.set("metadata", metadata))
class _AtariStorage(Storage):
    """Read-only storage presenting the per-split memory-mapped tensordicts as one dataset.

    Every sub-directory of ``path`` is expected to be named after an integer
    (the split id) and to contain a memory-mapped tensordict.

    Args:
        path (str or Path): directory containing one folder per split.
    """

    def __init__(self, path):
        self.path = Path(path)

        # One folder per split, named after the split id.
        self.splits = sorted(
            int(name)
            for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )
        self._split_tds = []
        frames_per_split = {}
        for split in self.splits:
            self._split_tds.append(self._load_split(self.path / str(split)))
            # take away 1 because we padded with 1 empty val
            frames_per_split[split] = (
                self._split_tds[-1].get(("data", "observation")).shape[0] - 1
            )

        # Column 0: split id. Column 1: cumulative number of frames up to and
        # including that split. A leading [-1, 0] row makes row ``i`` / ``i + 1``
        # the lower / upper frame bound of the i-th split.
        frames_per_split = torch.tensor(
            [[split, length] for (split, length) in frames_per_split.items()]
        )
        frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0)
        self.frames_per_split = torch.cat(
            [torch.tensor([[-1, 0]]), frames_per_split],
            0,
        )

        # retrieve episodes: the running count of "terminated" flags
        self.episodes = torch.cumsum(
            torch.cat(
                [td.get(("data", "next", "terminated")) for td in self._split_tds], 0
            ),
            0,
        )
        super().__init__(max_size=len(self))

    def __len__(self):
        return self.frames_per_split[-1, 1].item()

    def _read_from_splits(self, item: int | torch.Tensor):
        """Fetch the frames at the global index (or 1d tensor of indices) ``item``."""
        # We need to allocate each item to its storage.
        # We don't assume each storage has the same size (too expensive to test)
        # so we keep a map of each storage cumulative length and retrieve the
        # storages one after the other.
        item = torch.as_tensor(item)
        if not item.ndim:
            is_int = True
            item = item.reshape(-1)
        else:
            is_int = False
        # in_split[i, j] is True when item j lies within the bounds of split i.
        in_split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (
            item >= self.frames_per_split[:-1, 1].unsqueeze(1)
        )
        split_tmp, idx = in_split.nonzero().unbind(-1)
        split = split_tmp.squeeze()
        idx = idx.squeeze()

        if not is_int:
            # ``nonzero`` lists the hits split by split: put them back in the
            # order of ``item``.
            split = torch.zeros_like(split_tmp)
            split[idx] = split_tmp
        split = self.frames_per_split[split + 1, 0]
        # Local index within the split.
        # NOTE(review): indexing rows with the split id assumes ids are 0..N-1.
        item = item - self.frames_per_split[split, 1]
        if is_int:
            item = item.squeeze()
            return self._proc_td(self._split_tds[split], item)
        unique_splits, split_inverse = torch.unique(split, return_inverse=True)
        # NOTE(review): results are concatenated split by split, so when ``item``
        # spans several splits the output order may differ from the order of
        # ``item`` - confirm whether callers rely on that order.
        out = [
            self._proc_td(self._split_tds[split_id], item[split_inverse == i])
            for i, split_id in enumerate(unique_splits.tolist())
        ]
        return torch.cat(out, 0)

    def _load_split(self, path):
        return TensorDict.load_memmap(path)

    def _proc_td(self, td, index):
        """Build the sample(s) at local ``index`` of the split tensordict ``td``.

        The next observation is read at ``index + 1`` and set to 0 wherever
        ``("next", "terminated")`` is set. ``("next", "done")`` copies
        ``("next", "terminated")``; ``("next", "truncated")`` and the root
        ``done`` / ``terminated`` / ``truncated`` entries are all zeros.
        """
        td_data = td.get("data")
        obs_ = td_data.get(("observation"))[index + 1]
        done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool()
        if done.any():
            if done.ndim:
                # squeeze(-1) keeps the index 1d even with a single hit
                obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(-1), 0)
            else:
                # Scalar index on a terminal step: same zero-filling as the
                # batched case (this was previously skipped).
                obs_ = torch.zeros_like(obs_)
        td_idx = td.empty()
        td_idx.set(("next", "observation"), obs_)
        non_tensor = td.exclude("data").to_dict()
        td_idx.update(td_data.apply(lambda x: x[index]))
        if isinstance(index, torch.Tensor) and index.ndim:
            td_idx.batch_size = [len(index)]
        td_idx.set_non_tensor("metadata", non_tensor)

        terminated = td_idx.get(("next", "terminated"))
        zterminated = torch.zeros_like(terminated)
        td_idx.set(("next", "done"), terminated.clone())
        td_idx.set(("next", "truncated"), zterminated)
        td_idx.set("terminated", zterminated)
        td_idx.set("done", zterminated)
        td_idx.set("truncated", zterminated)
        return td_idx

    def get(self, index):
        """Return the element(s) designated by ``index``.

        Supports integers, 1d tensors (or tensors whose dimension 1 has size
        1), ranges, lists, slices and tuples. For a tuple, the first element
        indexes the storage and the rest indexes the result.

        Raises:
            RuntimeError: if a tensor index cannot be reduced to 1d.
        """
        if isinstance(index, int):
            return self._read_from_splits(index)
        if isinstance(index, tuple):
            if len(index) == 1:
                return self.get(index[0])
            return self.get(index[0])[(Ellipsis, *index[1:])]
        if isinstance(index, torch.Tensor):
            if index.ndim <= 1:
                return self._read_from_splits(index)
            elif index.shape[1] == 1:
                index = index.squeeze(1)
                return self.get(index)
            else:
                raise RuntimeError("Only 1d tensors are accepted")
        if isinstance(index, (range, list)):
            return self[torch.tensor(index)]
        if isinstance(index, slice):
            # slice.indices resolves None, negative and out-of-range bounds
            # (and negative steps) against the storage length.
            return self.get(torch.arange(*index.indices(len(self))))
        return self[torch.arange(len(self))[index]]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources