Shortcuts

Source code for torchrl.trainers.helpers.replay_buffer

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch

from torchrl.data.replay_buffers.replay_buffers import (
    ReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.utils import DEVICE_TYPING


[docs]def make_replay_buffer( device: DEVICE_TYPING, cfg: "DictConfig" # noqa: F821 ) -> ReplayBuffer: # noqa: F821 """Builds a replay buffer using the config built from ReplayArgsConfig.""" device = torch.device(device) if not cfg.prb: sampler = RandomSampler() else: sampler = PrioritizedSampler( max_capacity=cfg.buffer_size, alpha=0.7, beta=0.5, ) buffer = TensorDictReplayBuffer( storage=LazyMemmapStorage( cfg.buffer_size, scratch_dir=cfg.buffer_scratch_dir, # device=device, # when using prefetch, this can overload the GPU memory ), sampler=sampler, pin_memory=device != torch.device("cpu"), prefetch=cfg.buffer_prefetch, batch_size=cfg.batch_size, ) return buffer
@dataclass class ReplayArgsConfig: """Generic Replay Buffer config struct.""" buffer_size: int = 1000000 # buffer size, in number of frames stored. Default=1e6 prb: bool = False # whether a Prioritized replay buffer should be used instead of a more basic circular one. buffer_scratch_dir: Optional[str] = None # directory where the buffer data should be stored. If none is passed, they will be placed in /tmp/ buffer_prefetch: int = 10 # prefetching queue length for the replay buffer

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources