# Source code for torchrl.data.datasets.openx
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import importlib.util
import io
import json
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import torch
from tensordict import make_tensordict, NonTensorData, pad, TensorDict
from tensordict.utils import _is_non_tensor
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
_has_datasets = importlib.util.find_spec("datasets", None) is not None
_has_tv = importlib.util.find_spec("torchvision", None) is not None
class OpenXExperienceReplay(BaseDatasetExperienceReplay):
    """Open X-Embodiment datasets experience replay.

    The Open X-Embodiment Dataset contains 1M+ real robot trajectories
    spanning 22 robot embodiments, collected through a collaboration between
    21 institutions, demonstrating 527 skills (160266 tasks).

    Website: https://robotics-transformer-x.github.io/

    GitHub: https://github.com/google-deepmind/open_x_embodiment

    Paper: https://arxiv.org/abs/2310.08864

    The data format follows the :ref:`TED convention <TED-format>`.

    .. note::
      Non-tensor data will be written in the tensordict data using the
      :class:`~tensordict.tensorclass.NonTensorData` primitive.
      For instance, the `language_instruction` field in the data will
      be stored in `data.get_non_tensor("language_instruction")` (or equivalently
      `data.get("language_instruction").data`). See the documentation of this
      class for more information on how to interact with non-tensor data
      stored in a :class:`~tensordict.TensorDict`.

    Args:
        dataset_id (str): The dataset to be downloaded.
            Must be part of ``OpenXExperienceReplay.available_datasets``.
        batch_size (int): Batch-size used during sampling.
            Can be overridden by `data.sample(batch_size)` if necessary.
            See ``num_slices`` and ``slice_len`` keyword arguments for a refined
            sampling strategy.
            If the ``batch_size`` is ``None`` (default), iterating over the
            dataset will deliver trajectories one at a time *whereas* calling
            :meth:`~.sample` will *still* require a batch-size to be provided.

    Keyword Args:
        shuffle (bool, optional): if ``True``, trajectories are delivered in a
            random order when the dataset is iterated over.
            If ``False``, the dataset is iterated over in the pre-defined order.
            For a downloaded dataset sampled with ``num_slices`` or
            ``slice_len``, ``shuffle=False`` requires ``replacement=False``.

            .. warning::
              shuffle=False will also impact the sampling. We advise users to
              create a copy of the dataset where the ``shuffle`` attribute of the
              sampler is set to ``False`` if they wish to enjoy the two different
              behaviours (shuffled and not) within the same code base.

        num_slices (int, optional): the number of slices in a batch. This
            corresponds to the number of trajectories present in a batch.
            Once collected, the batch is presented as a concatenation of
            sub-trajectories that can be recovered through `batch.reshape(num_slices, -1)`.
            The `batch_size` must be divisible by `num_slices` if provided.
            This argument is exclusive with ``slice_len``.
            If the ``num_slices`` argument equals the ``batch_size``, each sample
            will belong to a different trajectory.
            If neither ``slice_len`` nor ``num_slices`` is provided:
            whenever a trajectory is longer than the
            batch-size, a contiguous slice of it of length `batch_size` will be
            sampled. If the trajectory length is insufficient, an exception will
            be raised unless `pad` is not `None`.
        slice_len (int, optional): the length of slices in a batch. This
            corresponds to the length of trajectories present in a batch.
            Once collected, the batch is presented as a concatenation of
            sub-trajectories that can be recovered through `batch.reshape(-1, slice_len)`.
            The `batch_size` must be divisible by `slice_len` if provided.
            This argument is exclusive with ``num_slices``.
            If the ``slice_len`` argument equals ``1``, each sample
            will belong to a different trajectory.
            If neither ``slice_len`` nor ``num_slices`` is provided:
            whenever a trajectory is longer than the
            batch-size, a contiguous slice of it of length `batch_size` will be
            sampled. If the trajectory length is insufficient, an exception will
            be raised unless `pad` is not `None`.

            .. note::
              The ``slice_len`` (but not ``num_slices``) can be used when
              iterating over a dataset without passing a batch-size in the
              constructor. In these cases, a random sub-sequence of the
              trajectory will be chosen.

        replacement (bool, optional): if ``False``, sampling will be done
            without replacement. Defaults to ``True`` for downloaded datasets,
            ``False`` for streamed datasets.
        pad (bool, float or None, optional): if ``True``, trajectories of insufficient length
            given the `slice_len` or `num_slices` arguments will be padded with
            0s. If another value is provided, it will be used for padding. If
            ``False`` or ``None`` (default) any encounter with a trajectory of
            insufficient length will raise an exception.
            Padding is only available with streamed datasets.
        root (Path or str, optional): The OpenX dataset root directory.
            The actual dataset memory-mapped files will be saved under
            `<root>/<dataset_id>`. If none is provided, it defaults to
            ``~/.cache/torchrl/openx``.
        streaming (bool, optional): if ``True``, the data won't be downloaded but
            read from a stream instead. Defaults to ``True`` when neither
            ``streaming`` nor ``download`` is provided, and to ``not download``
            otherwise.

            .. note:: The formatting of the data __will change__ when `download=True`
              compared to `streaming=True`. If the data is downloaded and
              the sampler is left untouched (i.e., `num_slices=None`, `slice_len=None`
              and `sampler=None`), transitions will be sampled randomly from
              the dataset. This isn't possible at a reasonable cost with
              `streaming=True`: in this case, trajectories will be sampled
              one at a time and delivered as such (with cropping to comply with
              the batch-size etc). The behaviour of the two modalities is
              much more similar when `num_slices` and `slice_len` are specified,
              as in these cases, views of sub-episodes will be returned in both
              cases.

        download (bool or str, optional): Whether the dataset should be downloaded if
            not found. Defaults to ``not streaming`` (hence ``False`` when neither
            ``download`` nor ``streaming`` is provided). Download can also be
            passed as "force", in which case the downloaded data will be overwritten.
        sampler (Sampler, optional): the sampler to be used. If none is provided,
            a slice sampler is built from ``num_slices``/``slice_len``/``replacement``
            for downloaded datasets (or the replay buffer default sampler is used if
            neither is provided), and a streaming sampler is used for streamed datasets.
        writer (Writer, optional): the writer to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading.
        transform (Transform, optional): Transform to be executed when sample() is called.
            To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
        split_trajs (bool, optional): if ``True``, the trajectories will be split
            along the first dimension and padded to have a matching shape.
            To split the trajectories, the ``"done"`` signal will be used, which
            is recovered via ``done = truncated | terminated``. In other words,
            it is assumed that any ``truncated`` or ``terminated`` signal is
            equivalent to the end of a trajectory.
            Defaults to ``False``. Not implemented yet: ``True`` raises a
            ``NotImplementedError``.
        strict_length (bool, optional): if ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch.
            Be mindful that this can result in effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.

    Examples:
        >>> from torchrl.data.datasets import OpenXExperienceReplay
        >>> import tempfile
        >>> # Download the data, and sample 128 elements in each batch out of two trajectories
        >>> num_slices = 2
        >>> with tempfile.TemporaryDirectory() as root:
        ...     dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128,
        ...         num_slices=num_slices, download=True, streaming=False,
        ...         root=root,
        ...         )
        ...     for batch in dataset:
        ...         print(batch.reshape(num_slices, -1))
        ...         break
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, 64, 8]), device=cpu, dtype=torch.float64, is_shared=False),
                discount: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                episode: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
                index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
                is_init: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False),
                language_embedding: Tensor(shape=torch.Size([2, 64, 512]), device=cpu, dtype=torch.float64, is_shared=False),
                language_instruction: NonTensorData(
                    data='lift open green garbage can lid',
                    batch_size=torch.Size([2, 64]),
                    device=cpu,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: TensorDict(
                            fields={
                                image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
                                state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
                            batch_size=torch.Size([2, 64]),
                            device=cpu,
                            is_shared=False),
                        reward: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([2, 64]),
                    device=cpu,
                    is_shared=False),
                observation: TensorDict(
                    fields={
                        image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
                        state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
                    batch_size=torch.Size([2, 64]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2, 64]),
            device=cpu,
            is_shared=False)
        >>> # Read data from a stream. Deliver entire trajectories when iterating
        >>> dataset = OpenXExperienceReplay("cmu_stretch",
        ...     num_slices=num_slices, download=False, streaming=True)
        >>> for data in dataset: # data does not have a consistent shape
        ...     break
        >>> # Define batch-size dynamically
        >>> data = dataset.sample(128) # delivers 2 sub-trajectories of length 64

    """

    available_datasets = [
        "fractal20220817_data",
        "kuka",
        "bridge",
        "taco_play",
        "jaco_play",
        "berkeley_cable_routing",
        "roboturk",
        "nyu_door_opening_surprising_effectiveness",
        "viola",
        "berkeley_autolab_ur5",
        "toto",
        "language_table",
        "columbia_cairlab_pusht_real",
        "stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
        "nyu_rot_dataset_converted_externally_to_rlds",
        "stanford_hydra_dataset_converted_externally_to_rlds",
        "austin_buds_dataset_converted_externally_to_rlds",
        "nyu_franka_play_dataset_converted_externally_to_rlds",
        "maniskill_dataset_converted_externally_to_rlds",
        "furniture_bench_dataset_converted_externally_to_rlds",
        "cmu_franka_exploration_dataset_converted_externally_to_rlds",
        "ucsd_kitchen_dataset_converted_externally_to_rlds",
        "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
        "austin_sailor_dataset_converted_externally_to_rlds",
        "austin_sirius_dataset_converted_externally_to_rlds",
        "bc_z",
        "usc_cloth_sim_converted_externally_to_rlds",
        "utokyo_pr2_opening_fridge_converted_externally_to_rlds",
        "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
        "utokyo_saytap_converted_externally_to_rlds",
        "utokyo_xarm_pick_and_place_converted_externally_to_rlds",
        "utokyo_xarm_bimanual_converted_externally_to_rlds",
        "robo_net",
        "berkeley_mvp_converted_externally_to_rlds",
        "berkeley_rpt_converted_externally_to_rlds",
        "kaist_nonprehensile_converted_externally_to_rlds",
        "stanford_mask_vit_converted_externally_to_rlds",
        "tokyo_u_lsmo_converted_externally_to_rlds",
        "dlr_sara_pour_converted_externally_to_rlds",
        "dlr_sara_grid_clamp_converted_externally_to_rlds",
        "dlr_edan_shared_control_converted_externally_to_rlds",
        "asu_table_top_converted_externally_to_rlds",
        "stanford_robocook_converted_externally_to_rlds",
        "eth_agent_affordances",
        "imperialcollege_sawyer_wrist_cam",
        "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
        "uiuc_d3field",
        "utaustin_mutex",
        "berkeley_fanuc_manipulation",
        "cmu_playing_with_food",
        "cmu_play_fusion",
        "cmu_stretch",
        "berkeley_gnm_recon",
        "berkeley_gnm_cory_hall",
        "berkeley_gnm_sac_son",
    ]

    # some very high number that should be above all trajectory lengths in the dataset
    _MAX_TRAJ_LEN = 1_000_000

    def __init__(
        self,
        dataset_id,
        batch_size: int | None = None,
        *,
        shuffle: bool = True,
        num_slices: int | None = None,
        slice_len: int | None = None,
        pad: float | bool | None = None,
        replacement: bool | None = None,
        streaming: bool | None = None,
        root: str | Path | None = None,
        download: bool | str | None = None,
        sampler: Sampler | None = None,
        writer: Writer | None = None,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        transform: "torchrl.envs.Transform" | None = None,  # noqa-F821
        split_trajs: bool = False,
        strict_length: bool = True,
    ):
        # Resolve the download / streaming modes: streaming is the default when
        # neither is given, otherwise one is the negation of the other.
        if download is None and streaming is None:
            download = False
            streaming = True
        elif download is None:
            download = not streaming
        elif streaming is None:
            streaming = not download
        self.download = download
        self.streaming = streaming
        self.dataset_id = dataset_id
        self.split_trajs = split_trajs
        self.shuffle = shuffle
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad
        self.strict_length = strict_length
        if (self.num_slices is not None) and (self.slice_len is not None):
            raise ValueError("num_slices or slice_len can be not None, but not both.")
        if split_trajs:
            raise NotImplementedError
        if not streaming:
            if replacement is None:
                replacement = True
            # ``False`` is documented as equivalent to ``None`` (no padding).
            if pad is not None and pad is not False:
                raise RuntimeError(
                    "the `pad` argument is to be used only with streaming datasets."
                )
            if root is None:
                root = _get_root_dir("openx")
                os.makedirs(root, exist_ok=True)
            self.root = Path(root)
            if self.download == "force" or (
                self.download and not self._is_downloaded()
            ):
                if download == "force" and os.path.exists(self.data_path_root):
                    shutil.rmtree(self.data_path_root)
                storage = self._download_and_preproc()
            else:
                storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
            if num_slices is not None or slice_len is not None:
                if sampler is not None:
                    raise ValueError(
                        "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                    )
                if replacement:
                    if not self.shuffle:
                        raise RuntimeError(
                            "shuffle=False can only be used when replacement=False."
                        )
                    sampler = SliceSampler(
                        num_slices=num_slices,
                        slice_len=slice_len,
                        strict_length=strict_length,
                    )
                else:
                    sampler = SliceSamplerWithoutReplacement(
                        num_slices=num_slices,
                        slice_len=slice_len,
                        strict_length=strict_length,
                        shuffle=self.shuffle,
                    )
        else:
            if replacement is True:
                # replacement can be False or None
                raise RuntimeError(
                    "replacement=True is not available with streamed datasets."
                )
            self.root = None
            if download:
                raise ValueError(
                    "download and streaming cannot be set to ``True`` concomitantly."
                )
            storage = _StreamingStorage(
                dataset_id=dataset_id,
                shuffle=self.shuffle,
                num_slices=self.num_slices,
                slice_len=self.slice_len,
                pad=self.pad,
            )
            if sampler is None:
                sampler = _StreamingSampler()
        if writer is None:
            writer = ImmutableDatasetWriter()
        if collate_fn is None:
            collate_fn = _collate_id
        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    def __iter__(self):
        """Iterate over the dataset.

        With a batch-size, this defers to the replay buffer iteration. Without
        one, streamed datasets yield their trajectories directly, and
        downloaded datasets yield either ``slice_len``-long slices or whole
        trajectories (``num_slices`` of them per item, 1 by default).
        """
        if self._batch_size is None:
            # we can still iterate over the dataset
            if isinstance(self._storage, _StreamingStorage):
                yield from self._storage
            elif self.slice_len is not None and self.num_slices is None:
                try:
                    # truncate the trajs with slice_len: temporarily use
                    # slice_len as batch-size and recurse into the batched path
                    self._batch_size = self.slice_len
                    self.num_slices = 1
                    self.slice_len = None
                    yield from self
                finally:
                    # restore the original configuration
                    self.slice_len = self._batch_size
                    self._batch_size = None
                    self.num_slices = None
            else:
                # if we don't have a batch size but we know how many trajectories
                # we want in each batch, we can build that on the fly.
                # The only time we can do this is if num_slices is given but not
                # slice_len.
                num_slices = self.num_slices
                if not num_slices:
                    num_slices = 1
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    strict_length=False,
                    shuffle=self.shuffle,
                )
                # a batch-size larger than any trajectory so that whole
                # trajectories are delivered
                batch_size = self._MAX_TRAJ_LEN
                yield from TensorDictReplayBuffer(
                    storage=self._storage,
                    sampler=sampler,
                    batch_size=batch_size,
                    transform=self._transform,
                )
        else:
            yield from super().__iter__()

    @property
    def data_path(self):
        """Location of the (possibly split) memory-mapped data, ``None`` when streaming."""
        if self.streaming:
            return None
        if self.split_trajs:
            return Path(self.root) / (self.dataset_id + "_split")
        return self.data_path_root

    @property
    def data_path_root(self):
        """``<root>/<dataset_id>``, or ``None`` when streaming."""
        if self.streaming:
            return None
        return self.root / self.dataset_id

    def _is_downloaded(self):
        return os.path.exists(self.data_path_root)

    def _download_and_preproc(self):
        """Download the dataset and write it as a memory-mapped TensorDict.

        The data is fetched into a temporary cache directory and traversed
        twice: a first pass counts the frames and builds a zeroed, formatted
        template step used to allocate the memory-mapped storage under
        ``<root>/<dataset_id>``; a second pass formats every episode and writes
        it into that storage.

        Returns:
            a :class:`~torchrl.data.replay_buffers.storages.TensorStorage`
            wrapping the locked, memory-mapped data.

        Raises:
            ImportError: if the ``datasets`` library is not installed.
            RuntimeError: if the dataset does not contain any step.
        """
        if not _has_datasets:
            raise ImportError(
                f"the `datasets` library is required for the dataset {self.dataset_id}."
            )
        import datasets

        try:
            import tqdm

            _has_tqdm = True
        except ImportError:
            _has_tqdm = False

        with tempfile.TemporaryDirectory() as cache_dir:
            dataset = datasets.load_dataset(
                "jxu124/OpenX-Embodiment",
                self.dataset_id,
                streaming=False,
                split="train",
                cache_dir=cache_dir,
                trust_remote_code=True,
            )

            # First pass: count the frames and build a template step. A single
            # step is enough to know the structure of the data.
            total_frames = 0
            td = None
            pbar = tqdm.tqdm(dataset, desc="counting") if _has_tqdm else dataset
            for data in pbar:
                steps = data["data.pickle"]["steps"]
                if td is None:
                    for step in steps:
                        td = _make_tensordict_image_conv(step).zero_()
                        # format td: requires td to have a non-null batch_size
                        td = td.expand(2, *td.shape)
                        _format_data(td, 0)
                        td = td[0]
                        break
                total_frames += len(steps)
            if td is None:
                raise RuntimeError(
                    f"The dataset {self.dataset_id} does not contain any step."
                )

            td_data = td.expand(total_frames)

            def expand_non_tensor(x):
                # non-tensor entries must be turned into stacks to be expanded
                if isinstance(x, NonTensorData):
                    return x.maybe_to_stack()
                return x

            td_data = td_data._apply_nest(
                expand_non_tensor,
                is_leaf=lambda x: issubclass(x, torch.Tensor) or _is_non_tensor(x),
            )
            td_data = td_data.memmap_like(self.root / self.dataset_id)

            # Second pass: format each episode and write it in the storage.
            # The bar counts frames, so it is advanced manually by the episode
            # length rather than wrapping the (episode-level) iterable.
            pbar = tqdm.tqdm(desc="preproc", total=total_frames) if _has_tqdm else None
            idx0 = 0
            for episode, data in enumerate(dataset):
                steps = data["data.pickle"]["steps"]
                if not len(steps):
                    # empty episodes hold no frame (they were counted as 0)
                    continue
                current_ep = torch.stack(
                    [_make_tensordict_image_conv(step) for step in steps]
                ).contiguous()
                _format_data(current_ep, episode)
                idx1 = idx0 + len(current_ep)
                td_data[idx0:idx1] = current_ep
                idx0 = idx1
                if pbar is not None:
                    pbar.update(current_ep.shape[0])
            if pbar is not None:
                pbar.close()
            return TensorStorage(td_data.lock_())
class _StreamingStorage(Storage):
    """Read-only storage streaming OpenX episodes with the ``datasets`` library.

    The storage cannot be indexed: :meth:`get` only uses the requested number
    of elements and fills it with (sliced) episodes read from the stream,
    restarting the stream when it is exhausted.

    Args:
        dataset_id (str): the OpenX sub-dataset name.
        repo (str, optional): the repository to load the dataset from.
        split (str, optional): the dataset split.
        base_path (str, optional): the key under which each record stores its
            episode. Falsy values mean the record is the episode itself.
        shuffle (bool, optional): whether the stream is shuffled.
        truncate (bool, optional): whether :meth:`get` crops the concatenated
            slices to the requested number of elements.
        num_slices (int, optional): number of slices per batch.
        slice_len (int, optional): length of each slice.
        pad (bool, float or None, optional): padding value for trajectories
            shorter than the slice length (see ``_slice_data``).
    """

    SLICE_MISMATCH = "The batch_size {} must be divisible by num_slices {} or slice_len {} if provided."

    def __init__(
        self,
        dataset_id: str,
        repo: str = "jxu124/OpenX-Embodiment",
        split="train",
        base_path="data.pickle",
        shuffle: bool = True,
        truncate: bool = True,
        num_slices=None,
        slice_len=None,
        pad=None,
    ):
        self.shuffle = shuffle
        self.dataset_id = dataset_id
        self.repo = repo
        self.split = split
        self._init()
        self.base_path = base_path
        self.truncate = truncate
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad

    def _init(self):
        """(Re)create the streamed dataset and its iterator."""
        if not _has_datasets:
            raise ImportError(
                f"the `datasets` library is required for the dataset {self.dataset_id}."
            )
        import datasets

        dataset = datasets.load_dataset(
            self.repo, self.dataset_id, streaming=True, split=self.split
        )
        if self.shuffle:
            dataset = dataset.shuffle()
        self.dataset = dataset
        self.dataset_iter = iter(dataset)

    def _process_episode(self, data, episode: int):
        """Turn a raw streamed record into a formatted, stacked episode."""
        if self.base_path:
            data = data[self.base_path]
        data = torch.stack(
            [_make_tensordict_image_conv(step) for step in data["steps"]]
        ).contiguous()
        _format_data(data, episode)
        return data

    def __iter__(self):
        """Yield the episodes of the stream, sliced if ``slice_len`` is set.

        Each episode is tagged with its position in the stream.
        """
        for episode, data in enumerate(self.dataset):
            data = self._process_episode(data, episode)
            if self.slice_len is not None:
                yield _slice_data(data, slice_len=self.slice_len, pad_value=self.pad)
            else:
                yield data

    def get(self, index: range | torch.Tensor) -> Any:
        """Gather ``len(index)`` elements made of slices of streamed episodes.

        Args:
            index (range or torch.Tensor): a range, or a tensor of consecutive
                indices; only the number of requested elements is used.

        Raises:
            RuntimeError: if ``index`` holds non-consecutive indices.
            ValueError: if the batch-size is not divisible by ``num_slices`` or
                ``slice_len``.
        """
        if not isinstance(index, range):
            if (index[1:] != index[:-1] + 1).any():
                # we use a range to indicate how much data we want
                raise RuntimeError("iterable datasets do not support indexing.")
            index = range(index.shape[0])
        batch_size = index.stop
        if self.num_slices is not None:
            if batch_size % self.num_slices != 0:
                raise ValueError(
                    self.SLICE_MISMATCH.format(
                        batch_size, self.num_slices, self.slice_len
                    )
                )
            slice_len = batch_size // self.num_slices
        elif self.slice_len is not None:
            if batch_size % self.slice_len != 0:
                raise ValueError(
                    self.SLICE_MISMATCH.format(
                        batch_size, self.num_slices, self.slice_len
                    )
                )
            slice_len = self.slice_len
        else:
            # Neither num_slices nor slice_len: a single contiguous slice of
            # length batch_size is taken from a trajectory.
            slice_len = batch_size

        total = 0
        data_list = []
        episode = 0
        while total < batch_size:
            try:
                data = next(self.dataset_iter)
            except StopIteration:
                # restart the stream once exhausted
                self.dataset_iter = iter(self.dataset)
                data = next(self.dataset_iter)
            data = self._process_episode(data, episode)
            data = _slice_data(data, slice_len=slice_len, pad_value=self.pad)
            data_list.append(data)
            total += data.numel()
            episode += 1
        data = torch.cat(data_list)
        if self.truncate:
            return data[: index.stop]
        return data

    def dumps(self, path):
        """Write the state dict as ``state_dict.json`` in the ``path`` directory."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        with open(path / "state_dict.json", "w") as file:
            json.dump(self.state_dict(), file)

    def state_dict(self) -> Dict[str, Any]:
        return {
            "repo": self.repo,
            "split": self.split,
            "dataset_id": self.dataset_id,
            "shuffle": self.shuffle,
            "base_path": self.base_path,
            "truncated": self.truncate,
            "num_slices": self.num_slices,
            "slice_len": self.slice_len,
            "pad": self.pad,
        }

    def loads(self, path):
        """Load the state dict written by :meth:`dumps` in the ``path`` directory."""
        path = Path(path)
        with open(path / "state_dict.json") as file:
            state_dict = json.load(file)
        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        for key, val in state_dict.items():
            # state_dict() stores the ``truncate`` attribute under the
            # ``"truncated"`` key: map it back to the attribute name.
            if key == "truncated":
                key = "truncate"
            setattr(self, key, val)
        # recreate the stream with the loaded settings
        self._init()

    def __len__(self):
        raise RuntimeError(
            f"{type(self)} does not have a length. Use a downloaded dataset to "
            f"access this property."
        )
def _slice_data(data: TensorDict, slice_len, pad_value):
if data.shape[-1] < slice_len:
if pad_value is None:
raise RuntimeError(
f"The trajectory length ({data.shape[-1]}) is shorter than the slice length ({slice_len}). "
f"Decrease the slice length or provide a padding value."
)
if pad_value is True:
pad_value = 0
return pad(data, [0, slice_len - data.shape[-1]], value=pad_value)
if data.ndim == 1:
random_range = (
((data.shape[-1] - slice_len) * torch.rand(())).floor().int().item()
)
random_range = slice(random_range, random_range + slice_len)
else:
raise NotImplementedError(data)
data = data[..., random_range]
truncated = data.get(("next", "truncated"))
truncated = torch.index_fill(
truncated,
dim=data.ndim - 1,
value=True,
index=torch.as_tensor(-1, device=truncated.device),
)
done = data.get(("next", "done"))
data.set(("next", "truncated"), truncated)
data.set(("next", "done"), truncated | done)
return data
class _StreamingSampler(Sampler):
    """Stateless sampler for streamed datasets.

    Streamed storages cannot be indexed, so the sampler always returns
    ``range(batch_size)``: the storage only reads how many elements are
    requested.
    """

    def __init__(self):
        pass

    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        info: dict = {}
        return range(batch_size), info

    def _empty(self):
        return None

    def dumps(self, path):
        """No-op: the sampler holds no state."""

    def loads(self, path):
        """No-op: the sampler holds no state."""

    def state_dict(self) -> Dict[str, Any]:
        return dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """No-op: the sampler holds no state."""
# Renaming applied by ``_format_data``: OpenX step flags (keys on the left)
# are moved to the keys used by the TED format (values on the right).
OPENX_KEY_MAP = dict(
    [
        ("is_first", "is_init"),
        ("is_last", ("next", "done")),
        ("is_terminal", ("next", "terminated")),
        ("reward", ("next", "reward")),
    ]
)
def _format_data(data: TensorDict, episode: int):
    """Convert a stacked OpenX episode to the TED layout, in place.

    * ``("next", "observation")`` is the observation shifted by one step,
      padded at the end of the trajectory;
    * keys are renamed following ``OPENX_KEY_MAP``;
    * ``("next", "truncated")`` is set to ``done & ~terminated``;
    * done/terminated/truncated/reward get a trailing singleton dimension, and
      root-level done/terminated/truncated entries are filled with zeros;
    * an ``"episode"`` entry holds the ``episode`` id for every step.
    """
    next_observation = pad(data.get("observation")[1:], [0, 1])
    data.set(("next", "observation"), next_observation)
    for old_key, new_key in OPENX_KEY_MAP.items():
        data.rename_key_(old_key, new_key)
    next_done = data.get(("next", "done"))
    next_terminated = data.get(("next", "terminated"))
    data.set(("next", "truncated"), next_done & ~next_terminated)
    for key in ("done", "terminated", "truncated", "reward"):
        value = data.get(("next", key)).unsqueeze(-1)
        data.set(("next", key), value)
        if key == "reward":
            continue
        # the root-level flags are all ``False``
        data.set(key, torch.zeros_like(data.get(("next", key))))
    episode_ids = torch.full(
        data.shape, episode, device=data.device, dtype=torch.int
    )
    data.set("episode", episode_ids)
def _make_tensordict_image_conv(data):
    """Build a TensorDict from a raw OpenX step, decoding an encoded image first.

    In some datasets the images are not well converted: ``observation["image"]``
    is a mapping holding the encoded ``"bytes"``. Such bytes are loaded with PIL
    and converted to a tensor with torchvision before
    :func:`~tensordict.make_tensordict` is called. Steps without such an entry
    (no observation or image, or an image that is not a mapping with a
    ``"bytes"`` key) are converted as-is.

    .. note:: ``data`` is modified in place when an image is decoded.

    Args:
        data (dict): a single step of an episode.

    Returns:
        the step as a :class:`~tensordict.TensorDict`.

    Raises:
        ImportError: if an image must be decoded but ``torchvision`` is not
            installed.
    """
    # Only the lookup is guarded: an error raised while decoding the image
    # must surface instead of being silently swallowed.
    try:
        img_bytes = data["observation"]["image"]["bytes"]
    except (KeyError, TypeError, IndexError):
        # nothing to decode (missing entry, or image already decoded)
        return make_tensordict(data)
    if not _has_tv:
        raise ImportError(
            "the `torchvision` library is required to read the image observation."
        )
    import torchvision.transforms.v2.functional
    from PIL import Image

    img = Image.open(io.BytesIO(img_bytes))
    tensor = torchvision.transforms.v2.functional.pil_to_tensor(img)
    data["observation"]["image"] = tensor
    return make_tensordict(data)