RayReplayBuffer
- class torchrl.data.RayReplayBuffer(*args, ray_init_config: dict[str, Any] | None = None, remote_config: dict[str, Any] | None = None, **kwargs)[source]
A Ray implementation of the Replay Buffer that can be extended and sampled remotely.
- Keyword Arguments:
ray_init_config (dict[str, Any], optional) – keyword arguments to pass to ray.init().
remote_config (dict[str, Any], optional) – keyword arguments to pass to cls.as_remote(). Defaults to torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG.
See also
ReplayBuffer
for a list of other keyword arguments. The writer, sampler and storage should be passed as constructors to prevent serialization issues. Transform constructors should be passed through the transform_factory argument.
Example
>>> import asyncio
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> from torchrl.collectors.distributed.ray import RayCollector
>>> from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
>>> from torchrl.envs.libs.gym import GymEnv
>>>
>>> async def main():
...     # 1. Create environment factory
...     def env_maker():
...         return GymEnv("Pendulum-v1", device="cpu")
...
...     policy = TensorDictModule(
...         nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
...     )
...
...     buffer = RayReplayBuffer()
...
...     # 2. Define distributed collector
...     remote_config = {
...         "num_cpus": 1,
...         "num_gpus": 0,
...         "memory": 5 * 1024**3,
...         "object_store_memory": 2 * 1024**3,
...     }
...     distributed_collector = RayCollector(
...         [env_maker],
...         policy,
...         total_frames=600,
...         frames_per_batch=200,
...         remote_configs=remote_config,
...         replay_buffer=buffer,
...     )
...
...     print("start")
...     distributed_collector.start()
...
...     while True:
...         while not len(buffer):
...             print("waiting")
...             await asyncio.sleep(1)  # use asyncio.sleep instead of time.sleep
...         print("sample", buffer.sample(32))
...         # break at some point
...         break
...
...     await distributed_collector.async_shutdown()
>>>
>>> if __name__ == "__main__":
...     asyncio.run(main())
- add(*args, **kwargs)[source]
Add a single element to the replay buffer.
- Parameters:
data (Any) – data to be added to the replay buffer
- Returns:
index where the data lives in the replay buffer.
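A minimal sketch of adding a single element, shown here on a local ReplayBuffer backed by a ListStorage for simplicity (the remote buffer exposes the same call):
>>> import torch
>>> from torchrl.data import ListStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=ListStorage(10))
>>> index = rb.add(torch.zeros(3))  # stores one element, returns its index
>>> len(rb)
1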
- append_transform(*args, **kwargs)[source]
Appends transform at the end.
Transforms are applied in order when sample is called.
- Parameters:
transform (Transform) – The transform to be appended
- Keyword Arguments:
invert (bool, optional) – if True, the transform will be inverted (forward calls will be called during writing and inverse calls during reading). Defaults to False.
Example
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyMemmapStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4)
>>> data = TensorDict({"a": torch.zeros(10)}, [10])
>>> def t(data):
...     data += 1
...     return data
>>> rb.append_transform(t, invert=True)
>>> rb.extend(data)
>>> assert (data == 1).all()
- classmethod as_remote(remote_config=None)
Creates an instance of a remote ray class.
- Parameters:
cls (Python Class) – class to be remotely instantiated.
remote_config (dict) – the resources (e.g., the number of CPU cores) to reserve for this class. Defaults to torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG.
- Returns:
A function that creates ray remote class instances.
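A hedged sketch of one possible usage; instantiating the returned class via .remote() follows standard Ray actor conventions and is an assumption here, not something this page documents:
>>> from torchrl.data import RayReplayBuffer
>>> # reserve resources for the buffer actor (keys follow ray.remote's options)
>>> RemoteBuffer = RayReplayBuffer.as_remote(remote_config={"num_cpus": 2})
>>> buffer = RemoteBuffer.remote()  # assumed Ray-style actor instantiation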
- dumps(path)[source]
Saves the replay buffer on disk at the specified path.
- Parameters:
path (Path or str) – path where to save the replay buffer.
Examples
>>> import tempfile
>>> import tqdm
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
>>> import torch
>>> from tensordict import TensorDict
>>> # Build and populate the replay buffer
>>> S = 1_000_000
>>> sampler = PrioritizedSampler(S, 1.1, 1.0)
>>> # sampler = RandomSampler()
>>> storage = LazyMemmapStorage(S)
>>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler)
>>>
>>> for _ in tqdm.tqdm(range(100)):
...     td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100])
...     rb.extend(td)
...     sample = rb.sample(32)
...     rb.update_tensordict_priority(sample)
>>> # save and load the buffer
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     rb.dumps(tmpdir)
...
...     sampler = PrioritizedSampler(S, 1.1, 1.0)
...     # sampler = RandomSampler()
...     storage = LazyMemmapStorage(S)
...     rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler)
...     rb_load.loads(tmpdir)
...     assert len(rb) == len(rb_load)
- empty()[source]
Empties the replay buffer and resets the cursor to 0.
- extend(*args, **kwargs)[source]
Extends the replay buffer with one or more elements contained in an iterable.
If present, the inverse transforms will be called.
- Parameters:
data (iterable) – collection of data to be added to the replay buffer.
- Returns:
Indices of the data added to the replay buffer.
Warning
extend() can have an ambiguous signature when dealing with lists of values, which should be interpreted either as a PyTree (in which case all elements in the list will be put in a slice in the stored PyTree in the storage) or as a list of values to add one at a time. To solve this, TorchRL makes a clear-cut distinction between list and tuple: a tuple will be viewed as a PyTree, while a list (at the root level) will be interpreted as a stack of values to add one at a time to the buffer, as illustrated in the sketch below. For ListStorage instances, only unbound elements can be provided (no PyTrees).
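A minimal sketch of the root-level list semantics, assuming a ListStorage-backed buffer (each list element becomes its own entry):
>>> from torchrl.data import ListStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=ListStorage(10))
>>> indices = rb.extend([1, 2, 3])  # a root-level list: three entries, added one at a time
>>> len(rb)
3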
- insert_transform(index: int, transform: Transform, *, invert: bool = False) → ReplayBuffer[source]
Inserts transform.
Transforms are executed in order when sample is called.
- Parameters:
index (int) – Position to insert the transform.
transform (Transform) – The transform to be inserted.
- Keyword Arguments:
invert (bool, optional) – if True, the transform will be inverted (forward calls will be called during writing and inverse calls during reading). Defaults to False.
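A minimal sketch of transform ordering; it assumes plain callables are wrapped the same way append_transform wraps them, which is an assumption rather than something this page states:
>>> from torchrl.data import LazyMemmapStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4)
>>> rb.append_transform(lambda data: data * 2)
>>> rb.insert_transform(0, lambda data: data + 1)  # position 0: runs first on sampled data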
- loads(path)[source]
Loads a replay buffer state at the given path.
The buffer should have matching components and be saved using dumps().
- Parameters:
path (Path or str) – path where the replay buffer was saved.
See dumps() for more info.
- register_load_hook(hook: Callable[[Any], Any])[source]
Registers a load hook for the storage.
Note
Hooks are currently not serialized when saving a replay buffer: they must be manually re-initialized every time the buffer is created.
- register_save_hook(hook: Callable[[Any], Any])[source]
Registers a save hook for the storage.
Note
Hooks are currently not serialized when saving a replay buffer: they must be manually re-initialized every time the buffer is created.
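A minimal sketch of registering hooks; the signatures follow the Callable[[Any], Any] annotation above, while the exact point at which the storage invokes each hook is an assumption:
>>> from torchrl.data import ListStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=ListStorage(10))
>>> def save_hook(data):
...     # hypothetical pass-through, applied to the data when dumping
...     return data
>>> rb.register_save_hook(save_hook)
>>> rb.register_load_hook(lambda data: data)  # hypothetical pass-through on load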
- sample(*args, **kwargs)[source]
Samples a batch of data from the replay buffer.
Uses Sampler to sample indices, and retrieves them from Storage.
- Parameters:
batch_size (int, optional) – size of data to be collected. If none is provided, this method will sample a batch-size as indicated by the sampler.
return_info (bool) – whether to return info. If True, the result is a tuple (data, info). If False, the result is the data.
- Returns:
A batch of data selected in the replay buffer. A tuple containing this batch and info if return_info flag is set to True.
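A minimal sketch of both call patterns, assuming a LazyTensorStorage-backed buffer with a default batch size:
>>> import torch
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=8)
>>> rb.extend(torch.arange(100))
>>> batch = rb.sample()  # uses the sampler's / buffer's default batch size (8)
>>> batch, info = rb.sample(16, return_info=True)  # explicit batch size, with info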
- property sampler
The sampler of the replay buffer.
The sampler must be an instance of Sampler.
- set_sampler(sampler)[source]
Sets a new sampler in the replay buffer and returns the previous sampler.
- set_storage(storage)[source]
Sets a new storage in the replay buffer and returns the previous storage.
- Parameters:
storage (Storage) – the new storage for the buffer.
collate_fn (callable, optional) – if provided, the collate_fn is set to this value. Otherwise it is reset to a default value.
- set_writer(writer)[source]
Sets a new writer in the replay buffer and returns the previous writer.
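A minimal sketch of swapping components on an (empty) buffer; each setter returns the component it replaces:
>>> from torchrl.data import ListStorage, ReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import RandomSampler
>>> rb = ReplayBuffer(storage=ListStorage(100))
>>> previous_sampler = rb.set_sampler(RandomSampler())  # returns the old sampler
>>> previous_storage = rb.set_storage(ListStorage(1000))  # returns the old storage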
- property storage
The storage of the replay buffer.
The storage must be an instance of Storage.
- property writer
The writer of the replay buffer.
The writer must be an instance of Writer.