
OpenMLExperienceReplay

class torchrl.data.datasets.OpenMLExperienceReplay(name: str, batch_size: int, sampler: Optional[Sampler] = None, writer: Optional[Writer] = None, collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, transform: Optional[Transform] = None)

An experience replay for OpenML data.

This class provides an easy entry point to public datasets. See: Dua, D. and Graff, C. (2017). UCI Machine Learning Repository. http://archive.ics.uci.edu/ml

The data is accessed via scikit-learn, so make sure scikit-learn and pandas are installed before retrieving the data:

$ pip install scikit-learn pandas -U
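Before constructing the buffer, you may want to check that both dependencies can be imported. The snippet below is a convenience sketch that uses only the Python standard library; the helper name missing_dependencies is hypothetical and is not part of torchrl. Note that scikit-learn is imported under the module name sklearn.

```python
import importlib.util

def missing_dependencies(modules=("sklearn", "pandas")):
    """Return the names of modules that cannot be found in the current environment.

    Hypothetical helper, not a torchrl API.
    """
    return [name for name in modules if importlib.util.find_spec(name) is None]

missing = missing_dependencies()
if missing:
    print("Missing packages:", ", ".join(missing), "- run: pip install scikit-learn pandas -U")
```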
Parameters:
  • name (str) – the following datasets are supported: "adult_num", "adult_onehot", "mushroom_num", "mushroom_onehot", "covertype", "shuttle" and "magic".

  • batch_size (int) – the batch size to use during sampling.

  • sampler (Sampler, optional) – the sampler to be used. If none is provided, a default RandomSampler() will be used.

  • writer (Writer, optional) – the writer to be used. If none is provided, a default RoundRobinWriter() will be used.

  • collate_fn (callable, optional) – merges a list of samples into a mini-batch of Tensor(s)/outputs. Used for batched loading from a map-style dataset.

  • pin_memory (bool) – whether pin_memory() should be called on the replay buffer samples.

  • prefetch (int, optional) – number of upcoming batches to prefetch using multithreading.

  • transform (Transform, optional) – Transform to be executed when sample() is called. To chain transforms, use the Compose class.
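The following is a toy, pure-Python model of how sampler, batch_size and collate_fn fit together: the sampler picks indices into the storage, and the collate function merges the selected samples into one mini-batch. The helper names sample_indices and collate_dicts are hypothetical, and this toy sampler draws uniformly with replacement as an illustrative assumption; torchrl's own RandomSampler and collate logic operate on tensors and TensorDicts rather than Python lists and dicts.

```python
import random

def sample_indices(num_stored, batch_size, rng):
    """Toy sampler (hypothetical): draw batch_size indices uniformly, with replacement."""
    return [rng.randrange(num_stored) for _ in range(batch_size)]

def collate_dicts(samples):
    """Toy collate_fn (hypothetical): merge a list of per-sample dicts into a dict of lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}

# Ten stored samples, each a small record with a feature and a label.
storage = [{"obs": i, "label": i % 2} for i in range(10)]
rng = random.Random(0)

indices = sample_indices(len(storage), batch_size=4, rng=rng)
batch = collate_dicts([storage[i] for i in indices])
print(indices, batch)
```

Swapping in a different sampler only changes which indices are drawn; the collate step is unaffected.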

add(data: TensorDictBase) -> int

Add a single element to the replay buffer.

Parameters:

data (Any) – data to be added to the replay buffer

Returns:

index where the data lives in the replay buffer.

append_transform(transform: Transform) -> None

Appends a transform to the end of the transform chain.

Transforms are applied in order when sample() is called.

Parameters:

transform (Transform) – The transform to be appended
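To illustrate why appending matters for ordering, here is a toy transform chain in plain Python. The class TransformChainSketch is hypothetical and only models the documented behaviour (transforms run in list order, and an appended transform runs last); it is not torchrl's Compose.

```python
from typing import Callable, List

class TransformChainSketch:
    """Toy stand-in for a transform chain: applies callables in list order."""

    def __init__(self):
        self.transforms: List[Callable] = []

    def append_transform(self, transform: Callable) -> None:
        # Appended transforms go at the end, so they run last.
        self.transforms.append(transform)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data

chain = TransformChainSketch()
chain.append_transform(lambda x: x + 1)   # runs first
chain.append_transform(lambda x: x * 10)  # runs second
print(chain(2))  # 30, i.e. (2 + 1) * 10
```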

empty()

Empties the replay buffer and resets the cursor to 0.
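A toy model of add() and empty() in plain Python, assuming a fixed-capacity storage and a round-robin write cursor (the default writer listed in the parameters). ToyReplayBuffer is hypothetical and is not torchrl's implementation; it only shows how the returned index relates to the cursor and what resetting the cursor implies.

```python
class ToyReplayBuffer:
    """Toy buffer (hypothetical) with a fixed capacity and a round-robin write cursor."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.storage = [None] * capacity
        self.cursor = 0
        self.size = 0

    def add(self, data):
        index = self.cursor
        self.storage[index] = data
        # Wrap around once the storage is full, overwriting the oldest slot.
        self.cursor = (self.cursor + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
        return index  # where the element now lives

    def empty(self):
        # Drop all stored data and reset the cursor to 0.
        self.storage = [None] * self.capacity
        self.cursor = 0
        self.size = 0

rb = ToyReplayBuffer(capacity=2)
print([rb.add(x) for x in "abc"])  # [0, 1, 0]: "c" overwrites "a"
rb.empty()
print(rb.add("d"))  # 0, writing restarts at the first slot
```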

extend(tensordicts: Union[List, TensorDictBase]) -> Tensor

Extends the replay buffer with one or more elements contained in an iterable.

If transforms are present, their inverse transforms will be called on the data.

Parameters:

tensordicts (iterable) – collection of data to be added to the replay buffer.

Returns:

Indices of the data added to the replay buffer.
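A toy sketch of extend() in plain Python, assuming an unbounded list as storage. The function extend and its inverse_transform argument are hypothetical stand-ins: the sketch only shows that each element passes through an inverse transform before being stored and that the indices of the new elements are returned. torchrl returns those indices as a Tensor and applies the inverse of its own transform chain.

```python
def extend(storage, items, inverse_transform=None):
    """Toy extend (hypothetical): inverse-transform each item, store it, return new indices."""
    start = len(storage)
    for item in items:
        if inverse_transform is not None:
            item = inverse_transform(item)
        storage.append(item)
    return list(range(start, len(storage)))

storage = ["x"]  # one element already stored at index 0
indices = extend(storage, [1, 2, 3], inverse_transform=lambda v: v * 2)
print(indices, storage)  # [1, 2, 3] ['x', 2, 4, 6]
```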

insert_transform(index: int, transform: Transform) -> None

Inserts a transform at the given position in the transform chain.

Transforms are executed in order when sample is called.

Parameters:
  • index (int) – Position to insert the transform.

  • transform (Transform) – The transform to be inserted.
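A minimal sketch of how the index argument determines position, assuming insert_transform follows Python's list.insert semantics (including negative indices). That assumption is not stated above, and the helper below with its string placeholders is hypothetical rather than torchrl code.

```python
def insert_transform(chain, index, transform):
    """Toy version (hypothetical): place transform at position index, like list.insert."""
    chain.insert(index, transform)

transforms = ["normalize", "to_float"]
insert_transform(transforms, 0, "crop")   # becomes the first transform to run
insert_transform(transforms, -1, "clip")  # lands just before the last transform
print(transforms)  # ['crop', 'normalize', 'clip', 'to_float']
```

Because transforms run in order when sample() is called, the position chosen here is also the execution order.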

sample(batch_size: Optional[int] = None, return_info: bool = False, include_info: Optional[bool] = None) -> TensorDictBase

Samples a batch of data from the replay buffer.

Uses Sampler to sample indices, and retrieves them from Storage.

Parameters:
  • batch_size (int, optional) – size of data to be collected. If none is provided, this method will sample a batch size as indicated by the sampler.

  • return_info (bool) – whether to return info. If True, the result is a tuple (data, info). If False, the result is the data.

Returns:

A tensordict containing a batch of data selected from the replay buffer, or a tuple containing this tensordict and info if return_info is set to True.
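A toy sketch of the return_info switch in plain Python. The function sample and the contents of its info dict (here only the sampled indices under a hypothetical "index" key) are assumptions for illustration; the sketch shows the documented shape of the result: the batch alone by default, or a (batch, info) tuple when return_info is True.

```python
import random
from typing import Any, Dict, List, Optional, Tuple, Union

def sample(storage: List[Any], batch_size: int, return_info: bool = False,
           seed: Optional[int] = None) -> Union[List[Any], Tuple[List[Any], Dict[str, Any]]]:
    """Toy sample() (hypothetical): return the batch, or (batch, info) if return_info is True."""
    rng = random.Random(seed)
    indices = [rng.randrange(len(storage)) for _ in range(batch_size)]
    batch = [storage[i] for i in indices]
    if return_info:
        # Hypothetical info dict: here it only records which indices were sampled.
        return batch, {"index": indices}
    return batch

data = list("abcdef")
plain_batch = sample(data, batch_size=3, seed=0)
batch, info = sample(data, batch_size=3, return_info=True, seed=0)
print(plain_batch, batch, info)
```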
