D4RLExperienceReplay¶
- class torchrl.data.datasets.D4RLExperienceReplay(name, batch_size: int, sampler: Optional[Sampler] = None, writer: Optional[Writer] = None, collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, transform: Optional[Transform] = None, split_trajs: bool = False, from_env: bool = True, use_truncated_as_done: bool = True, **env_kwargs)[source]¶
An experience replay class for D4RL.
To install D4RL, follow the instructions on the official repo.
The replay buffer contains the env specs under D4RLExperienceReplay.specs.
If present, metadata will be written in D4RLExperienceReplay.metadata and excluded from the dataset. The transitions are reconstructed using done = terminated | truncated, and the ("next", "observation") of "done" states are zeroed.
- Parameters:
name (str) – the name of the D4RL env to get the data from.
batch_size (int) – the batch size to use during sampling.
sampler (Sampler, optional) – the sampler to be used. If none is provided a default RandomSampler() will be used.
writer (Writer, optional) – the writer to be used. If none is provided a default RoundRobinWriter() will be used.
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset.
pin_memory (bool) – whether pin_memory() should be called on the rb samples.
prefetch (int, optional) – number of next batches to be prefetched using multithreading.
transform (Transform, optional) – Transform to be executed when sample() is called. To chain transforms, use the Compose class.
split_trajs (bool, optional) – if True, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the "done" signal will be used, which is recovered via done = truncated | terminated. In other words, it is assumed that any truncated or terminated signal marks the end of a trajectory. For some D4RL datasets, this may not be true. It is up to the user to make accurate choices regarding this usage of split_trajs. Defaults to False.
from_env (bool, optional) – if True, env.get_dataset() will be used to retrieve the dataset. Otherwise, d4rl.qlearning_dataset() will be used. Defaults to True.
Note: Using from_env=False will provide less data than from_env=True. For instance, the info keys will be left out. Usually, from_env=False with terminate_on_end=True will lead to the same result as from_env=True, with the latter containing meta-data and info entries that the former does not possess.
Note: The keys in from_env=True and from_env=False may unexpectedly differ. In particular, the "truncated" key (used to determine the end of an episode) may be absent when from_env=False but present otherwise, leading to a different slicing when split_trajs is enabled.
use_truncated_as_done (bool, optional) – if True, done = terminated | truncated. Otherwise, only the terminated key is used. Defaults to True.
**env_kwargs (key-value pairs) – additional kwargs for d4rl.qlearning_dataset(). Supports terminate_on_end (False by default) or other kwargs if defined by the D4RL library.
Examples
>>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay
>>> from torchrl.envs import ObservationNorm
>>> data = D4RLExperienceReplay("maze2d-umaze-v1")
>>> # we can append transforms to the dataset
>>> data.append_transform(ObservationNorm(loc=-1, scale=1.0))
>>> data.sample(128)
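The constructor flags described in the parameter list can be combined at construction time. The following is a minimal sketch; the exact keys and tensor shapes depend on the D4RL dataset being downloaded:
>>> # sketch: passing the dataset flags discussed in the parameter list
>>> data = D4RLExperienceReplay(
...     "maze2d-umaze-v1",
...     batch_size=128,
...     split_trajs=False,
...     from_env=True,
...     use_truncated_as_done=True,
... )
>>> sample = data.sample()  # uses the batch_size passed at construction
>>> print(sample["observation"].shape)  # e.g. torch.Size([128, obs_dim])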
- add(data: TensorDictBase) int ¶
Adds a single element to the replay buffer.
- Parameters:
data (Any) – data to be added to the replay buffer
- Returns:
index where the data lives in the replay buffer.
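For instance, continuing from the example above (a sketch: the transition is taken from the buffer itself so that its keys match the stored specs):
>>> td = data.sample(1)[0]  # a single transition carrying the dataset's keys
>>> index = data.add(td)    # returns the index where the element was written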
- append_transform(transform: Transform) None ¶
Appends a transform at the end of the chain.
Transforms are applied in order when sample is called.
- Parameters:
transform (Transform) – The transform to be appended
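For instance, transforms can be chained with Compose before being appended (a sketch; the chosen transforms and their arguments are illustrative):
>>> from torchrl.envs import Compose, DoubleToFloat, ObservationNorm
>>> # the composed transform is applied to every sampled batch, in order
>>> data.append_transform(Compose(DoubleToFloat(), ObservationNorm(loc=0.0, scale=1.0)))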
- empty()¶
Empties the replay buffer and resets the cursor to 0.
- extend(tensordicts: Union[List, TensorDictBase]) Tensor ¶
Extends the replay buffer with one or more elements contained in an iterable.
If present, the inverse transforms will be called.
- Parameters:
tensordicts (iterable) – collection of data to be added to the replay buffer.
- Returns:
Indices of the data added to the replay buffer.
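For instance, continuing from the example above (a sketch: the batch is taken from the buffer itself so that its keys match the stored specs):
>>> batch = data.sample(32)
>>> indices = data.extend(batch)  # a tensor of 32 indices, one per added element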
- insert_transform(index: int, transform: Transform) None ¶
Inserts a transform at the given position.
Transforms are executed in order when sample is called.
- Parameters:
index (int) – Position to insert the transform.
transform (Transform) – The transform to be inserted.
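For instance (a sketch; the ObservationNorm arguments are illustrative):
>>> from torchrl.envs import ObservationNorm
>>> # insert at position 0 so this transform runs before any previously appended ones
>>> data.insert_transform(0, ObservationNorm(loc=0.0, scale=1.0))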
- sample(batch_size: Optional[int] = None, return_info: bool = False, include_info: Optional[bool] = None) TensorDictBase ¶
Samples a batch of data from the replay buffer.
Uses Sampler to sample indices, and retrieves them from Storage.
- Parameters:
batch_size (int, optional) – size of data to be collected. If none is provided, this method will sample a batch-size as indicated by the sampler.
return_info (bool) – whether to return info. If True, the result is a tuple (data, info). If False, the result is the data.
- Returns:
A tensordict containing a batch of data selected from the replay buffer, or a tuple containing this tensordict and info if the return_info flag is set to True.
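For instance (a sketch; the content of info depends on the sampler being used):
>>> sample = data.sample(128)                           # a TensorDict with batch size 128
>>> sample, info = data.sample(128, return_info=True)   # info contents depend on the sampler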