Source code for torchrl.data.datasets.atari_dqn
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
import gzip
import io
import json
import os
import shutil
import subprocess
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Callable
import numpy as np
import torch
from tensordict import MemoryMappedTensor, TensorDict, TensorDictBase
from torch import multiprocessing as mp
from torchrl._utils import logger as torchrl_logger
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
from torchrl.data.replay_buffers.samplers import (
SamplerWithoutReplacement,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.utils import _classproperty
[docs]class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
"""Atari DQN Experience replay class.
The Atari DQN dataset (https://offline-rl.github.io/) is a collection of 5 training
iterations of DQN over each of the Arari 2600 games for a total of 200 million frames.
The sub-sampling rate (frame-skip) is equal to 4, meaning that each game dataset
has 50 million steps in total.
The data format follows the :ref:`TED convention <TED-format>`. Since the dataset is quite heavy,
the data formatting is done on-line, at sampling time.
To make training more modular, we split the dataset in each of the Atari games
and separate each training round. Consequently, each dataset is presented as
a Storage of length 50x10^6 elements. Under the hood, this dataset is split
in 50 memory-mapped tensordicts of length 1 million each.
Args:
dataset_id (str): The dataset to be downloaded.
Must be part of ``AtariDQNExperienceReplay.available_datasets``.
batch_size (int): Batch-size used during sampling.
Can be overridden by `data.sample(batch_size)` if necessary.
Keyword Args:
root (Path or str, optional): The AtariDQN dataset root directory.
The actual dataset memory-mapped files will be saved under
`<root>/<dataset_id>`. If none is provided, it defaults to
``~/.cache/torchrl/atari`.
num_procs (int, optional): number of processes to launch for preprocessing.
Has no effect whenever the data is already downloaded. Defaults to 0
(no multiprocessing used).
download (bool or str, optional): Whether the dataset should be downloaded if
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
in which case the downloaded data will be overwritten.
sampler (Sampler, optional): the sampler to be used. If none is provided
a default RandomSampler() will be used.
writer (Writer, optional): the writer to be used. If none is provided
a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
collate_fn (callable, optional): merges a list of samples to form a
mini-batch of Tensor(s)/outputs. Used when using batched
loading from a map-style dataset.
pin_memory (bool): whether pin_memory() should be called on the rb
samples.
prefetch (int, optional): number of next batches to be prefetched
using multithreading.
transform (Transform, optional): Transform to be executed when sample() is called.
To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
num_slices (int, optional): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
with ``slice_len``. Defaults to ``None`` (no slice sampling).
The ``sampler`` arg will override this value.
slice_len (int, optional): the length of the slices to be sampled. The batch-size
must be greater or equal to the ``slice_len`` argument and divisible
by it. Exclusive with ``num_slices``. Defaults to ``None`` (no slice sampling).
The ``sampler`` arg will override this value.
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
The ``sampler`` arg will override this value.
replacement (bool, optional): if ``False``, sampling will occur without replacement.
The ``sampler`` arg will override this value.
mp_start_method (str, optional): the start method for multiprocessed
download. Defaults to ``"fork"``.
Attributes:
available_datasets: list of available datasets, formatted as `<game_name>/<run>`. Example:
`"Pong/5"`, `"Krull/2"`, ...
dataset_id (str): the name of the dataset.
episodes (torch.Tensor): a 1d tensor indicating to what run each of the
1M frames belongs. To be used with :class:`~torchrl.data.replay_buffers.SliceSampler`
to cheaply sample slices of episodes.
Examples:
>>> from torchrl.data.datasets import AtariDQNExperienceReplay
>>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128)
>>> for data in dataset:
... print(data)
... break
TensorDict(
fields={
action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
metadata: NonTensorData(
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}},
batch_size=torch.Size([128]),
device=None,
is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
batch_size=torch.Size([128]),
device=None,
is_shared=False),
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
batch_size=torch.Size([128]),
device=None,
is_shared=False)
.. warning::
Atari-DQN does not provide the next observation after a termination signal.
In other words, there is no way to obtain the ``("next", "observation")`` state
when ``("next", "done")`` is ``True``. This value is filled with 0s but should
not be used in practice. If TorchRL's value estimators (:class:`~torchrl.objectives.values.ValueEstimator`)
are used, this should not be an issue.
.. note::
Because the construction of the sampler for episode sampling is slightly
convoluted, we made it easy for users to pass the arguments of the
:class:`~torchrl.data.replay_buffers.SliceSampler` directly to the
``AtariDQNExperienceReplay`` dataset: any of the ``num_slices`` or
``slice_len`` arguments will make the sampler an instance of
:class:`~torchrl.data.replay_buffers.SliceSampler`. The ``strict_length``
can also be passed.
>>> from torchrl.data.datasets import AtariDQNExperienceReplay
>>> from torchrl.data.replay_buffers import SliceSampler
>>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64)
>>> for data in dataset:
... print(data)
... print(data.get("index")) # indices are in 4 groups of consecutive values
... break
TensorDict(
fields={
action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
metadata: NonTensorData(
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}},
batch_size=torch.Size([128]),
device=None,
is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([128]),
device=None,
is_shared=False),
observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
batch_size=torch.Size([128]),
device=None,
is_shared=False)
tensor([2657628, 2657629, 2657630, 2657631, 2657632, 2657633, 2657634, 2657635,
2657636, 2657637, 2657638, 2657639, 2657640, 2657641, 2657642, 2657643,
2657644, 2657645, 2657646, 2657647, 2657648, 2657649, 2657650, 2657651,
2657652, 2657653, 2657654, 2657655, 2657656, 2657657, 2657658, 2657659,
2657660, 2657661, 2657662, 2657663, 2657664, 2657665, 2657666, 2657667,
2657668, 2657669, 2657670, 2657671, 2657672, 2657673, 2657674, 2657675,
2657676, 2657677, 2657678, 2657679, 2657680, 2657681, 2657682, 2657683,
2657684, 2657685, 2657686, 2657687, 2657688, 2657689, 2657690, 2657691,
1995687, 1995688, 1995689, 1995690, 1995691, 1995692, 1995693, 1995694,
1995695, 1995696, 1995697, 1995698, 1995699, 1995700, 1995701, 1995702,
1995703, 1995704, 1995705, 1995706, 1995707, 1995708, 1995709, 1995710,
1995711, 1995712, 1995713, 1995714, 1995715, 1995716, 1995717, 1995718,
1995719, 1995720, 1995721, 1995722, 1995723, 1995724, 1995725, 1995726,
1995727, 1995728, 1995729, 1995730, 1995731, 1995732, 1995733, 1995734,
1995735, 1995736, 1995737, 1995738, 1995739, 1995740, 1995741, 1995742,
1995743, 1995744, 1995745, 1995746, 1995747, 1995748, 1995749, 1995750])
.. note::
As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble`:
>>> from torchrl.data.datasets import AtariDQNExperienceReplay
>>> from torchrl.data.replay_buffers import ReplayBufferEnsemble
>>> # we change this parameter for quick experimentation, in practice it should be left untouched
>>> AtariDQNExperienceReplay._max_runs = 2
>>> dataset_asterix = AtariDQNExperienceReplay("Asterix/5", batch_size=128, slice_len=64, num_procs=4)
>>> dataset_pong = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64, num_procs=4)
>>> dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True)
>>> sample = dataset.sample()
>>> print("first sample, Asterix", sample[0])
first sample, Asterix TensorDict(
fields={
action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
index: TensorDict(
fields={
buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([64]),
device=None,
is_shared=False),
metadata: NonTensorData(
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
batch_size=torch.Size([64]),
device=None,
is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([64]),
device=None,
is_shared=False),
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
batch_size=torch.Size([64]),
device=None,
is_shared=False)
>>> print("second sample, Pong", sample[1])
second sample, Pong TensorDict(
fields={
action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
index: TensorDict(
fields={
buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([64]),
device=None,
is_shared=False),
metadata: NonTensorData(
data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Asterix/5'},
batch_size=torch.Size([64]),
device=None,
is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([64]),
device=None,
is_shared=False),
observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
batch_size=torch.Size([64]),
device=None,
is_shared=False)
>>> print("Aggregate (metadata hidden)", sample)
Aggregate (metadata hidden) LazyStackedTensorDict(
fields={
action: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
done: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
index: LazyStackedTensorDict(
fields={
buffer_ids: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 64]),
device=None,
is_shared=False,
stack_dim=0),
metadata: LazyStackedTensorDict(
fields={
},
exclusive_fields={
},
batch_size=torch.Size([2, 64]),
device=None,
is_shared=False,
stack_dim=0),
next: LazyStackedTensorDict(
fields={
done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 64]),
device=None,
is_shared=False,
stack_dim=0),
observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
terminated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
truncated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 64]),
device=None,
is_shared=False,
stack_dim=0)
"""
@_classproperty
def available_datasets(cls):
games = [
"AirRaid",
"Alien",
"Amidar",
"Assault",
"Asterix",
"Asteroids",
"Atlantis",
"BankHeist",
"BattleZone",
"BeamRider",
"Berzerk",
"Bowling",
"Boxing",
"Breakout",
"Carnival",
"Centipede",
"ChopperCommand",
"CrazyClimber",
"DemonAttack",
"DoubleDunk",
"ElevatorAction",
"Enduro",
"FishingDerby",
"Freeway",
"Frostbite",
"Gopher",
"Gravitar",
"Hero",
"IceHockey",
"Jamesbond",
"JourneyEscape",
"Kangaroo",
"Krull",
"KungFuMaster",
"MontezumaRevenge",
"MsPacman",
"NameThisGame",
"Phoenix",
"Pitfall",
"Pong",
"Pooyan",
"PrivateEye",
"Qbert",
"Riverraid",
"RoadRunner",
"Robotank",
"Seaquest",
"Skiing",
"Solaris",
"SpaceInvaders",
]
return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)]
# If we want to keep track of the original atari files
tmpdir = None
# use _max_runs for debugging, avoids downloading the entire dataset
_max_runs = None
def __init__(
self,
dataset_id: str,
batch_size: int | None = None,
*,
root: str | Path | None = None,
download: bool | str = True,
sampler=None,
writer=None,
transform: "Transform" | None = None, # noqa: F821
num_procs: int = 0,
num_slices: int | None = None,
slice_len: int | None = None,
strict_len: bool = True,
replacement: bool = True,
mp_start_method: str = "fork",
**kwargs,
):
if dataset_id not in self.available_datasets:
raise ValueError(
"The dataseet_id is not part of the available datasets. The dataset should be named <game_name>/<run> "
"where <game_name> is one of the Atari 2600 games and the run is a number betweeen 1 and 5. "
"The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets."
)
self.dataset_id = dataset_id
from torchrl.data.datasets.utils import _get_root_dir
if root is None:
root = _get_root_dir("atari")
self.root = root
self.num_procs = num_procs
self.mp_start_method = mp_start_method
if download == "force" or (download and not self._is_downloaded):
try:
self._download_and_preproc()
except Exception:
# remove temporary data
if os.path.exists(self.dataset_path):
shutil.rmtree(self.dataset_path)
raise
if self._downloaded_and_preproc:
storage = TensorStorage(TensorDict.load_memmap(self.dataset_path))
else:
storage = _AtariStorage(self.dataset_path)
if writer is None:
writer = ImmutableDatasetWriter()
if sampler is None:
if num_slices is not None or slice_len is not None:
if not replacement:
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices,
slice_len=slice_len,
trajectories=storage.episodes,
)
else:
sampler = SliceSampler(
num_slices=num_slices,
slice_len=slice_len,
trajectories=storage.episodes,
cache_values=True,
)
elif not replacement:
sampler = SamplerWithoutReplacement()
super().__init__(
storage=storage,
batch_size=batch_size,
writer=writer,
sampler=sampler,
collate_fn=lambda x: x,
transform=transform,
**kwargs,
)
@property
def episodes(self):
return self._storage.episodes
@property
def root(self) -> Path:
return self._root
@root.setter
def root(self, value):
self._root = Path(value)
@property
def dataset_path(self) -> Path:
return self._root / self.dataset_id
@property
def _downloaded_and_preproc(self):
return os.path.exists(self.dataset_path / "meta.json")
@property
def _is_downloaded(self):
if os.path.exists(self.dataset_path / "meta.json"):
return True
if os.path.exists(self.dataset_path / "processed.json"):
with open(self.dataset_path / "processed.json", "r") as jsonfile:
return json.load(jsonfile).get("processed", False) == self._max_runs
return False
def _download_and_preproc(self):
torchrl_logger.info(
f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while..."
)
if os.path.exists(self.dataset_path):
shutil.rmtree(self.dataset_path)
with tempfile.TemporaryDirectory() as tempdir:
if self.tmpdir is not None:
tempdir = self.tmpdir
if not os.listdir(tempdir):
os.makedirs(tempdir, exist_ok=True)
# get the list of runs
command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs"
output = subprocess.run(
command, shell=True, capture_output=True
) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
files = [
file.decode("utf-8").replace("$", "\$") # noqa: W605
for file in output.stdout.splitlines()
if file.endswith(b".gz")
]
self.remote_gz_files = self._list_runs(None, files)
remote_gz_files = list(self.remote_gz_files)
if not len(remote_gz_files):
raise RuntimeError(
"Could not load the file list. Did you install gsutil?"
)
total_runs = remote_gz_files[-1]
if self.num_procs == 0:
for run, run_files in self.remote_gz_files.items():
self._download_and_proc_split(
run,
run_files,
tempdir=tempdir,
dataset_path=self.dataset_path,
total_episodes=total_runs,
max_runs=self._max_runs,
multithreaded=True,
)
else:
func = functools.partial(
self._download_and_proc_split,
tempdir=tempdir,
dataset_path=self.dataset_path,
total_episodes=total_runs,
max_runs=self._max_runs,
multithreaded=False,
)
args = [
(run, run_files)
for (run, run_files) in self.remote_gz_files.items()
]
ctx = mp.get_context(self.mp_start_method)
with ctx.Pool(self.num_procs) as pool:
pool.starmap(func, args)
with open(self.dataset_path / "processed.json", "w") as file:
# we save self._max_runs such that changing the number of runs to process
# forces the data to be re-downloaded
json.dump({"processed": self._max_runs}, file)
@classmethod
def _download_and_proc_split(
cls,
run,
run_files,
*,
tempdir,
dataset_path,
total_episodes,
max_runs,
multithreaded=True,
):
if (max_runs is not None) and (run >= max_runs):
return
tempdir = Path(tempdir)
os.makedirs(tempdir / str(run))
files_str = " ".join(run_files) # .decode("utf-8")
torchrl_logger.info(f"downloading {files_str}")
if multithreaded:
command = f"gsutil -m cp {files_str} {tempdir}/{run}"
else:
command = f"gsutil cp {files_str} {tempdir}/{run}"
subprocess.run(
command, shell=True
) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
local_gz_files = cls._list_runs(tempdir / str(run))
# we iterate over the dict but this one has length 1
for run in local_gz_files:
path = dataset_path / str(run)
try:
cls._preproc_run(path, local_gz_files, run)
except Exception:
shutil.rmtree(path)
raise
shutil.rmtree(tempdir / str(run))
torchrl_logger.info(f"Concluded run {run} out of {total_episodes}")
@classmethod
def _preproc_run(cls, path, gz_files, run):
files = gz_files[run]
td = TensorDict({}, [])
path = Path(path)
for file in files:
name = str(Path(file).parts[-1]).split(".")[0]
with gzip.GzipFile(file, mode="rb") as f:
file_content = f.read()
file_content = io.BytesIO(file_content)
file_content = np.load(file_content)
t = torch.as_tensor(file_content)
# Create the memmap file
key = cls._process_name(name)
if key == ("data", "observation"):
shape = t.shape
shape = [shape[0] + 1] + list(shape[1:])
filename = path / "data" / "observation.memmap"
os.makedirs(filename.parent, exist_ok=True)
mmap = MemoryMappedTensor.empty(shape, dtype=t.dtype, filename=filename)
mmap[:-1].copy_(t)
td[key] = mmap
# td["data", "next", key[1:]] = mmap[1:]
else:
if key in (
("data", "reward"),
("data", "done"),
("data", "terminated"),
):
filename = path / "data" / "next" / (key[-1] + ".memmap")
os.makedirs(filename.parent, exist_ok=True)
mmap = MemoryMappedTensor.from_tensor(t, filename=filename)
td["data", "next", key[1:]] = mmap
else:
filename = path
for i, _key in enumerate(key):
if i == len(key) - 1:
_key = _key + ".memmap"
filename = filename / _key
os.makedirs(filename.parent, exist_ok=True)
mmap = MemoryMappedTensor.from_tensor(t, filename=filename)
td[key] = mmap
td.set_non_tensor("dataset_id", "/".join(path.parts[-3:-1]))
td.memmap_(path, copy_existing=False)
@staticmethod
def _process_name(name):
if name.endswith("_ckpt"):
name = name[:-5]
if "store" in name:
key = ("data", name.split("_")[1])
else:
key = (name,)
if key[-1] == "terminal":
key = (*key[:-1], "terminated")
return key
@classmethod
def _list_runs(cls, download_path, gz_files=None) -> dict:
path = download_path
if gz_files is None:
gz_files = []
for root, _, files in os.walk(path):
for file in files:
if file.endswith(".gz"):
gz_files.append(os.path.join(root, file))
runs = defaultdict(list)
for file in gz_files:
filename = Path(file).parts[-1]
name, episode, extension = str(filename).split(".")
episode = int(episode)
runs[episode].append(file)
return dict(sorted(runs.items(), key=lambda x: x[0]))
[docs] def preprocess(
self,
fn: Callable[[TensorDictBase], TensorDictBase],
dim: int = 0,
num_workers: int | None = None,
*,
chunksize: int | None = None,
num_chunks: int | None = None,
pool: mp.Pool | None = None,
generator: torch.Generator | None = None,
max_tasks_per_child: int | None = None,
worker_threads: int = 1,
index_with_generator: bool = False,
pbar: bool = False,
mp_start_method: str | None = None,
dest: str | Path,
num_frames: int | None = None,
):
# Copy data to a tensordict
with tempfile.TemporaryDirectory() as tmpdir:
first_item = self[0]
metadata = first_item.pop("metadata")
mmap = fn(first_item)
if num_frames is None:
num_frames = len(self)
mmap = mmap.expand(num_frames, *first_item.shape)
mmap = mmap.memmap_like(tmpdir, num_threads=32)
with mmap.unlock_():
mmap["_indices"] = torch.arange(mmap.shape[0])
mmap.memmap_(tmpdir, num_threads=32)
def func(mmap: TensorDictBase):
idx = mmap["_indices"]
orig = self[idx].exclude("metadata")
orig = fn(orig)
mmap.update(orig, inplace=True)
return
if dim != 0:
raise RuntimeError("dim != 0 is not supported.")
mmap.map(
fn=CloudpickleWrapper(func),
dim=dim,
num_workers=num_workers,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
generator=generator,
max_tasks_per_child=max_tasks_per_child,
worker_threads=worker_threads,
index_with_generator=index_with_generator,
mp_start_method=mp_start_method,
pbar=pbar,
)
with mmap.unlock_():
return TensorStorage(mmap.set("metadata", metadata))
class _AtariStorage(Storage):
def __init__(self, path):
self.path = Path(path)
def get_folders(path):
return [
name
for name in os.listdir(path)
if os.path.isdir(os.path.join(path, name))
]
# Usage
self.splits = []
folders = get_folders(path)
for folder in folders:
self.splits.append(int(Path(folder).parts[-1]))
self.splits = sorted(self.splits)
self._split_tds = []
frames_per_split = {}
for split in self.splits:
path = self.path / str(split)
self._split_tds.append(self._load_split(path))
# take away 1 because we padded with 1 empty val
frames_per_split[split] = (
self._split_tds[-1].get(("data", "observation")).shape[0] - 1
)
frames_per_split = torch.tensor(
[[split, length] for (split, length) in frames_per_split.items()]
)
frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0)
self.frames_per_split = torch.cat(
# [torch.tensor([[-1, 0]]), frames_per_split], 0
[torch.tensor([[-1, 0]]), frames_per_split],
0,
)
# retrieve episodes
self.episodes = torch.cumsum(
torch.cat(
[td.get(("data", "next", "terminated")) for td in self._split_tds], 0
),
0,
)
super().__init__(max_size=len(self))
def __len__(self):
return self.frames_per_split[-1, 1].item()
def _read_from_splits(self, item: int | torch.Tensor):
# We need to allocate each item to its storage.
# We don't assume each storage has the same size (too expensive to test)
# so we keep a map of each storage cumulative length and retrieve the
# storages one after the other.
item = torch.as_tensor(item)
if not item.ndim:
is_int = True
item = item.reshape(-1)
else:
is_int = False
split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (
item >= self.frames_per_split[:-1, 1].unsqueeze(1)
)
# split_tmp, idx = split.squeeze().nonzero().unbind(-1)
split_tmp, idx = split.nonzero().unbind(-1)
split = split_tmp.squeeze()
idx = idx.squeeze()
if not is_int:
split = torch.zeros_like(split_tmp)
split[idx] = split_tmp
split = self.frames_per_split[split + 1, 0]
item = item - self.frames_per_split[split, 1]
if is_int:
item = item.squeeze()
return self._proc_td(self._split_tds[split], item)
unique_splits, split_inverse = torch.unique(split, return_inverse=True)
unique_splits = unique_splits.tolist()
out = []
for i, split in enumerate(unique_splits):
_item = item[split_inverse == i] if split_inverse is not None else item
out.append(self._proc_td(self._split_tds[split], _item))
return torch.cat(out, 0)
def _load_split(self, path):
return TensorDict.load_memmap(path)
def _proc_td(self, td, index):
td_data = td.get("data")
obs_ = td_data.get(("observation"))[index + 1]
done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool()
if done.ndim and done.any():
obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(), 0)
td_idx = td.empty()
td_idx.set(("next", "observation"), obs_)
non_tensor = td.exclude("data").to_dict()
td_idx.update(td_data.apply(lambda x: x[index]))
if isinstance(index, torch.Tensor) and index.ndim:
td_idx.batch_size = [len(index)]
td_idx.set_non_tensor("metadata", non_tensor)
terminated = td_idx.get(("next", "terminated"))
zterminated = torch.zeros_like(terminated)
td_idx.set(("next", "done"), terminated.clone())
td_idx.set(("next", "truncated"), zterminated)
td_idx.set("terminated", zterminated)
td_idx.set("done", zterminated)
td_idx.set("truncated", zterminated)
return td_idx
def get(self, index):
if isinstance(index, int):
return self._read_from_splits(index)
if isinstance(index, tuple):
if len(index) == 1:
return self.get(index[0])
return self.get(index[0])[(Ellipsis, *index[1:])]
if isinstance(index, torch.Tensor):
if index.ndim <= 1:
return self._read_from_splits(index)
elif index.shape[1] == 1:
index = index.squeeze(1)
return self.get(index)
else:
raise RuntimeError("Only 1d tensors are accepted")
# with ThreadPoolExecutor(16) as pool:
# results = map(self.__getitem__, index.tolist())
# return torch.stack(list(results))
if isinstance(index, (range, list)):
return self[torch.tensor(index)]
if isinstance(index, slice):
start = index.start if index.start is not None else 0
stop = index.stop if index.stop is not None else len(self)
step = index.step if index.step is not None else 1
return self.get(torch.arange(start, stop, step))
return self[torch.arange(len(self))[index]]