Source code for torchrl.data.datasets.gen_dgrl

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import os
import shutil
import tarfile
import tempfile
import typing as tp
from pathlib import Path

import numpy as np
import torch

from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.envs.utils import _classproperty

_has_tqdm = importlib.util.find_spec("tqdm", None) is not None
_has_requests = importlib.util.find_spec("requests", None) is not None


class GenDGRLExperienceReplay(BaseDatasetExperienceReplay):
    """Gen-DGRL Experience Replay dataset.

    This dataset accompanies the paper "The Generalization Gap in Offline Reinforcement Learning".

        Arxiv: https://arxiv.org/abs/2312.05742
        GitHub: https://github.com/facebookresearch/gen_dgrl

    The data format follows the :ref:`TED convention <TED-format>`.

    This class gives you access to the ProcGen dataset. Each `dataset_id` registered
    in `GenDGRLExperienceReplay.available_datasets` consists of a particular task
    (`"bigfish"`, `"bossfight"`, ...) separated from a category (`"1M_E"`, `"1M_S"`,
    ...) by a dash (`"bigfish-1M_E"`, ...).

    During download and preparation, the data is downloaded as .tar files, where each
    trajectory is stored independently in a .npy file. Each of these files is
    extracted, written in a contiguous mmap tensor, and then cleared. This process
    can take several minutes per dataset. On a cluster, it is advisable to first run
    the download and preprocessing separately on different workers or processes for
    different datasets, and launch the training script in a second step.

    Args:
        dataset_id (str): the dataset to be downloaded. Must be part of
            :attr:`GenDGRLExperienceReplay.available_datasets`.
        batch_size (int, optional): Batch-size used during sampling. Can be overridden
            by `data.sample(batch_size)` if necessary.

    Keyword Args:
        root (Path or str, optional): The :class:`~torchrl.data.datasets.GenDGRLExperienceReplay`
            dataset root directory. The actual dataset memory-mapped files will be
            saved under `<root>/<dataset_id>`. If none is provided, it defaults to
            ``~/.cache/torchrl/gen_dgrl``.
        download (bool or str, optional): Whether the dataset should be downloaded if
            not found. Defaults to ``True``. Download can also be passed as
            ``"force"``, in which case the downloaded data will be overwritten.
        sampler (Sampler, optional): the sampler to be used. If none is provided,
            a default RandomSampler() will be used.
        writer (Writer, optional): the writer to be used. If none is provided,
            a default RoundRobinWriter() will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched loading from a
            map-style dataset.
        pin_memory (bool): whether pin_memory() should be called on the rb samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading.
        transform (Transform, optional): Transform to be executed when sample() is
            called. To chain transforms use the :obj:`Compose` class.

    Attributes:
        available_datasets: a list of accepted entries to be downloaded. These
            names correspond to the directory path in the huggingface dataset
            repository. If possible, the list will be dynamically retrieved from
            huggingface. If no internet connection is available, a cached version
            will be used.

    Examples:
        >>> import torch
        >>> torch.manual_seed(0)
        >>> from torchrl.data.datasets import GenDGRLExperienceReplay
        >>> d = GenDGRLExperienceReplay("bigfish-1M_E", batch_size=32)
        >>> for batch in d:
        ...     break
        >>> print(batch)

    """

    BASE_URL = "https://dl.fbaipublicfiles.com/DGRL/Procgen/Datasets/Compressed"
    # number of files extracted at a time
    _PROCESS_NPY_BATCH = 32
    split_trajs: bool = False

    @_classproperty
    def available_datasets(cls):
        datasets = [
            "bigfish",
            "bossfight",
            "caveflyer",
            "chaser",
            "climber",
            "coinrun",
            "dodgeball",
            "fruitbot",
            "heist",
            "jumper",
            "leaper",
            "maze",
            "miner",
            "ninja",
            "plunder",
            "starpilot",
        ]
        categories = [
            "1M_E",
            "1M_S",
            "10M",
            "25M",
            "level_1_E",
            "level_1_S",
            "level_40_E",
            "level_40_S",
        ]
        return ["-".join((ds, cat)) for cat in categories for ds in datasets]

    def __init__(
        self,
        dataset_id: str,
        batch_size: int = None,
        *,
        download: bool = True,
        root: str | None = None,
        **kwargs,
    ):
        self.dataset_id = dataset_id
        try:
            dataset, category = dataset_id.split("-")
        except Exception:
            category = dataset_id
            dataset = None
        self._dataset_name = dataset
        self._category_name = category
        if root is None:
            root = _get_root_dir("gen_dgrl")
            os.makedirs(root, exist_ok=True)
        self.root = root
        if download == "force" or (download and not self._is_downloaded()):
            if download == "force" and os.path.exists(self.data_path_root):
                shutil.rmtree(self.data_path_root)
            storage = TensorStorage(self._download_and_preproc())
        else:
            storage = TensorStorage(TensorDict.load_memmap(self.data_path_root))
        super().__init__(storage=storage, batch_size=batch_size, **kwargs)

    @property
    def data_path(self):
        if self.split_trajs:
            return Path(self.root) / (self.dataset_id + "_split")
        return self.data_path_root

    @property
    def data_path_root(self):
        return Path(self.root) / self.dataset_id

    def _is_downloaded(self):
        return os.path.exists(self.data_path_root)

    def _download_and_preproc(self):
        dataset, category = self._dataset_name, self._category_name
        link = self._build_urls_with_category_name(dataset, category)
        data_link = (category, self._fetch_file_name_from_link(link), link)
        with tempfile.TemporaryDirectory() as tmpdir:
            self._download_category_file(
                tmpdir, skip_downloaded_files=True, link=data_link
            )
            return self._unpack_category_file(
                tmpdir, clear_archive=True, link=data_link, category_name=category
            )

    @classmethod
    def _build_urls_with_category_name(
        cls, dataset, category_name: str
    ) -> tp.List[str]:
        path = [cls.BASE_URL, cls._convert_category_name(category_name)]
        path += [f"{dataset}.tar.xz"]
        return os.path.join(*path)

    @staticmethod
    def _convert_category_name(category_name: str) -> str:
        if category_name == "1M_E":
            return "1M/level_200/expert"
        elif category_name == "1M_S":
            return "1M/level_200/suboptimal"
        elif category_name == "10M":
            return "10M/level_200/expert"
        elif category_name == "25M":
            return "25M/level_200/expert"
        elif category_name == "level_1_S":
            return "100k/level_1/suboptimal"
        elif category_name == "level_40_S":
            return "100k/level_40/suboptimal"
        elif category_name == "level_1_E":
            return "100k/level_1/expert"
        elif category_name == "level_40_E":
            return "100k/level_40/expert"
        else:
            raise ValueError(f"Unrecognized category name {category_name}!")

    @staticmethod
    def _fetch_file_name_from_link(url: str) -> str:
        return os.path.split(url)[-1]

    @classmethod
    def _get_category_len(cls, category_name):
        if "1M" in category_name:
            return 1_000_000
        if "10M" in category_name:
            return 10_000_000
        if "25M" in category_name:
            return 25_000_000
        return 100_000

    def _unpack_category_file(
        self,
        download_folder: str,
        clear_archive: bool,
        category_name,
        link: str,
        batch=None,
    ):
        if batch is None:
            batch = self._PROCESS_NPY_BATCH
        _, file_name, _ = link
        file_path = os.path.join(download_folder, file_name)
        torchrl_logger.info(
            f"Unpacking dataset file {file_path} ({file_name}) to {download_folder}."
        )
        idx = 0
        td_memmap = None
        dataset_len = self._get_category_len(category_name)
        if _has_tqdm:
            from tqdm import tqdm

            pbar = tqdm(total=dataset_len)
        else:
            pbar = None
        mode = "r:xz" if str(file_path).endswith("xz") else "r"
        full = False
        with tarfile.open(file_path, mode) as tar:
            members = list(tar.getmembers())
            for i in range(0, len(members), batch):
                if full:
                    break
                submembers = [
                    member for member in members[i : i + batch] if member.isfile()
                ]
                for member in submembers:
                    if pbar is not None:
                        pbar.set_description(member.name)
                    npybuffer = tar.extractfile(member=member)
                    # npyfile = Path(download_folder) / member.name
                    npfile = np.load(npybuffer, allow_pickle=True)
                    td = TensorDict.from_dict(npfile.tolist())
                    td.set("observations", td.get("observations").to(torch.uint8))
                    td.set(("next", "observation"), td.get("observations")[1:])
                    td.set("observations", td.get("observations")[:-1])
                    td.rename_key_("observations", "observation")
                    td.rename_key_("dones", ("next", "done"))
                    td.rename_key_("actions", "action")
                    td.rename_key_("rewards", ("next", "reward"))
                    td.set(
                        ("next", "done"), td.get(("next", "done")).bool().unsqueeze(-1)
                    )
                    td.set(
                        ("next", "truncated"),
                        torch.zeros_like(td.get(("next", "done"))),
                    )
                    td.set(("next", "terminated"), td.get(("next", "done")))
                    td.set(
                        "terminated", torch.zeros_like(td.get(("next", "terminated")))
                    )
                    td.set("done", torch.zeros_like(td.get(("next", "done"))))
                    td.set("truncated", torch.zeros_like(td.get(("next", "truncated"))))
                    td.batch_size = td.get("observation").shape[:1]
                    if td_memmap is None:
                        td_memmap = (
                            td[0]
                            .expand(dataset_len)
                            .memmap_like(self.data_path_root, num_threads=16)
                        )
                    idx_end = idx + td.shape[0]
                    idx_end = min(idx_end, td_memmap.shape[0])
                    if pbar is not None:
                        pbar.update(td.shape[0])
                    length = idx_end - idx
                    if length > 0:
                        if length != td.shape[0]:
                            td_memmap[idx:idx_end] = td[:length]
                        else:
                            td_memmap[idx:idx_end] = td
                    else:
                        full = True
                    idx = idx_end
        if clear_archive:
            os.remove(file_path)
        return td_memmap

    @classmethod
    def _download_category_file(
        cls,
        download_folder: str,
        skip_downloaded_files: bool,
        link: str,
    ):
        _, file_name, url = link
        file_path = os.path.join(download_folder, file_name)

        if skip_downloaded_files and os.path.isfile(file_path):
            torchrl_logger.info(f"Skipping {file_path}, already downloaded!")
            return file_name, True

        in_progress_folder = os.path.join(download_folder, "_in_progress")
        os.makedirs(in_progress_folder, exist_ok=True)
        in_progress_file_path = os.path.join(in_progress_folder, file_name)

        torchrl_logger.info(
            f"Downloading dataset file {file_name} ({url}) to {in_progress_file_path}."
        )
        cls._download_with_progress_bar(url, in_progress_file_path)

        os.rename(in_progress_file_path, file_path)
        return file_name, True

    @classmethod
    def _download_with_progress_bar(cls, url: str, file_path: str):
        # taken from https://stackoverflow.com/a/62113293/986477
        if not _has_requests:
            raise ImportError(
                "The requests package is required for Gen-DGRL dataset download."
            )
        import requests

        resp = requests.get(url, stream=True)
        total = int(resp.headers.get("content-length", 0))
        if _has_tqdm:
            from tqdm import tqdm

            pbar = tqdm(
                desc=file_path,
                total=total,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            )
        else:
            pbar = None
        with open(file_path, "wb") as file:
            for data in resp.iter_content(chunk_size=1024):
                size = file.write(data)
                if pbar is not None:
                    pbar.update(size)
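
The snippet below is a minimal usage sketch drawn from the class docstring above, not part of the module source. The dataset_id "bigfish-1M_E" and the batch size are only illustrative; the first instantiation needs network access to dl.fbaipublicfiles.com and may take several minutes while the .tar archive is extracted into memory-mapped files under <root>/<dataset_id>.

import torch

from torchrl.data.datasets import GenDGRLExperienceReplay

torch.manual_seed(0)

# Registered entries follow the "<task>-<category>" naming scheme.
print(GenDGRLExperienceReplay.available_datasets[:4])

# First call downloads and preprocesses the data; pass download="force"
# to overwrite a previously prepared dataset, or root=... to change the
# default cache location (~/.cache/torchrl/gen_dgrl).
data = GenDGRLExperienceReplay("bigfish-1M_E", batch_size=32)

# Samples follow the TED convention: root keys such as "observation",
# "action" and "done", plus a "next" sub-tensordict with "reward", "done", ...
batch = data.sample()
print(batch)

As the docstring notes, on a cluster it can help to run this instantiation once per dataset_id in a standalone preprocessing job, so that the training job only loads the already-materialized memory-mapped files.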
