Source code for torchrl.data.replay_buffers.samplers

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Tuple, Union

import numpy as np
import torch

from torchrl._torchrl import (
    MinSegmentTreeFp32,
    MinSegmentTreeFp64,
    SumSegmentTreeFp32,
    SumSegmentTreeFp64,
)

from .storages import Storage
from .utils import _to_numpy, INT_CLASSES

_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."


class Sampler(ABC):
    """A generic sampler base class for composable Replay Buffers."""

    @abstractmethod
    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        ...

    def add(self, index: int) -> None:
        return

    def extend(self, index: torch.Tensor) -> None:
        return

    def update_priority(
        self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
    ) -> dict:
        return

    def mark_update(self, index: Union[int, torch.Tensor]) -> None:
        return

    @property
    def default_priority(self) -> float:
        return 1.0

    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        return

    @property
    def ran_out(self) -> bool:
        # by default, samplers never run out
        return False

    @abstractmethod
    def _empty(self):
        ...
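
Example (illustrative, not part of the module source): a concrete sampler only has to implement sample (returning a tuple of indices and an info dict) and _empty; every other hook has a no-op default. A minimal sketch of the contract, using the hypothetical name LastNSampler:

import torch

from torchrl.data.replay_buffers.samplers import Sampler


class LastNSampler(Sampler):
    """Hypothetical sampler returning the `batch_size` most recent indices."""

    def sample(self, storage, batch_size):
        if len(storage) == 0:
            raise RuntimeError("Cannot sample from an empty storage.")
        start = max(0, len(storage) - batch_size)
        index = torch.arange(start, len(storage))
        return index, {}

    def _empty(self):
        # no internal state to reset
        pass

The second element of the returned tuple is an info dict; samplers that compute per-sample metadata (such as the importance-sampling weights of PrioritizedSampler below) put it there.
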
class RandomSampler(Sampler):
    """A uniformly random sampler for composable replay buffers.

    This sampler takes no constructor arguments: at every call to
    :meth:`sample`, indices are drawn uniformly (with replacement) from the
    current range of the storage.
    """

    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
        if len(storage) == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        index = torch.randint(0, len(storage), (batch_size,))
        return index, {}

    def _empty(self):
        pass
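
Because RandomSampler.sample only inspects len(storage), its behaviour can be exercised without building a full replay buffer. A minimal sketch, assuming a hypothetical __len__-only stand-in for the storage:

import torch

from torchrl.data.replay_buffers.samplers import RandomSampler


class _FakeStorage:
    """Hypothetical stand-in: RandomSampler only calls len() on the storage."""

    def __len__(self):
        return 100


sampler = RandomSampler()
index, info = sampler.sample(_FakeStorage(), batch_size=16)
assert index.shape == (16,)  # indices drawn uniformly, with replacement
assert info == {}  # uniform sampling attaches no metadata
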
class SamplerWithoutReplacement(Sampler):
    """A data-consuming sampler that ensures that the same sample is not present in consecutive batches.

    Args:
        drop_last (bool, optional): if ``True``, the last incomplete sample (if any)
            will be dropped. If ``False``, this last sample will be kept and (unlike
            with torch dataloaders) completed with other samples from a fresh
            indices permutation.

    *Caution*: if the size of the storage changes between two calls, the samples
    will be re-shuffled (as we can't generally keep track of which samples have
    been sampled before and which haven't). Similarly, the storage content is
    expected to remain the same between two calls, but this is not enforced.

    When the sampler reaches the end of the list of available indices, a new
    sample order will be generated and the resulting indices will be completed
    with this new draw, which can lead to duplicated indices, unless the
    :obj:`drop_last` argument is set to ``True``.
    """

    def __init__(self, drop_last: bool = False):
        self._sample_list = None
        self.len_storage = 0
        self.drop_last = drop_last
        self._ran_out = False

    def _single_sample(self, len_storage, batch_size):
        index = self._sample_list[:batch_size]
        self._sample_list = self._sample_list[batch_size:]
        # check if we have enough elements for one more batch, assuming the same
        # batch size will be used each time sample is called
        if self._sample_list.numel() == 0 or (
            self.drop_last and len(self._sample_list) < batch_size
        ):
            self._ran_out = True
            self._sample_list = torch.randperm(len_storage)
        else:
            self._ran_out = False
        return index

    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        len_storage = len(storage)
        if len_storage == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        if self.len_storage != len_storage or self._sample_list is None:
            self._sample_list = torch.randperm(len_storage)
        if len_storage < batch_size and self.drop_last:
            raise ValueError(
                f"The batch size ({batch_size}) is greater than the storage capacity ({len_storage}). "
                "This makes it impossible to return a sample without repeating indices. "
                "Consider changing the sampler class or setting 'drop_last' to False."
            )
        self.len_storage = len_storage
        index = self._single_sample(len_storage, batch_size)
        # we 'always' return the indices. The 'drop_last' argument just instructs
        # the sampler to set `ran_out = True` whenever the next sample would be
        # too short. This is read by the replay buffer as a signal for an early
        # break of the __iter__().
        return index, {}

    @property
    def ran_out(self):
        return self._ran_out

    @ran_out.setter
    def ran_out(self, value):
        self._ran_out = value

    def _empty(self):
        self._sample_list = None
        self.len_storage = 0
        self._ran_out = False
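
A minimal sketch of the consumption semantics, again with a hypothetical __len__-only stand-in storage: with drop_last=True, 10 stored entries and a batch size of 4 yield exactly two batches of unique indices before ran_out flips to True, since the 2 leftover indices cannot form a complete batch.

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement


class _FakeStorage:
    """Hypothetical stand-in: sample() only calls len() on the storage."""

    def __len__(self):
        return 10


sampler = SamplerWithoutReplacement(drop_last=True)
storage = _FakeStorage()
seen = []
while not sampler.ran_out:
    index, _ = sampler.sample(storage, batch_size=4)
    seen.extend(index.tolist())
# two full batches, no repeated index; the permutation's last 2 indices were
# dropped because they could not form a complete batch
assert len(seen) == 8 and len(set(seen)) == 8
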
class PrioritizedSampler(Sampler):
    """Prioritized sampler for replay buffers.

    Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
    Prioritized experience replay." (https://arxiv.org/abs/1511.05952)

    Args:
        max_capacity (int): maximum capacity of the buffer.
        alpha (float): exponent α determines how much prioritization is used,
            with α = 0 corresponding to the uniform case.
        beta (float): importance sampling negative exponent.
        eps (float, optional): delta added to the priorities to ensure that the
            buffer does not contain null priorities. Defaults to 1e-8.
        dtype (torch.dtype, optional): dtype of the segment trees holding the
            priorities. Must be ``torch.float`` or ``torch.double``. Defaults
            to ``torch.float``.
        reduction (str, optional): the reduction method for multidimensional
            tensordicts (i.e., stored trajectories). Can be one of "max",
            "min", "median" or "mean". Defaults to "max".
    """

    def __init__(
        self,
        max_capacity: int,
        alpha: float,
        beta: float,
        eps: float = 1e-8,
        dtype: torch.dtype = torch.float,
        reduction: str = "max",
    ) -> None:
        if alpha <= 0:
            raise ValueError(
                f"alpha must be strictly greater than 0, got alpha={alpha}"
            )
        if beta < 0:
            raise ValueError(
                f"beta must be greater than or equal to 0, got beta={beta}"
            )
        self._max_capacity = max_capacity
        self._alpha = alpha
        self._beta = beta
        self._eps = eps
        self.reduction = reduction
        self.dtype = dtype
        self._init()

    def _init(self):
        if self.dtype in (torch.float, torch.float32):
            self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
            self._min_tree = MinSegmentTreeFp32(self._max_capacity)
        elif self.dtype in (torch.double, torch.float64):
            self._sum_tree = SumSegmentTreeFp64(self._max_capacity)
            self._min_tree = MinSegmentTreeFp64(self._max_capacity)
        else:
            raise NotImplementedError(
                f"dtype {self.dtype} not supported by PrioritizedSampler"
            )
        self._max_priority = 1.0

    def _empty(self):
        self._init()

    @property
    def default_priority(self) -> float:
        return (self._max_priority + self._eps) ** self._alpha

    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        if len(storage) == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        p_sum = self._sum_tree.query(0, len(storage))
        p_min = self._min_tree.query(0, len(storage))
        if p_sum <= 0:
            raise RuntimeError("non-positive p_sum")
        if p_min <= 0:
            raise RuntimeError("non-positive p_min")
        mass = np.random.uniform(0.0, p_sum, size=batch_size)
        index = self._sum_tree.scan_lower_bound(mass)
        if not isinstance(index, np.ndarray):
            index = np.array([index])
        if isinstance(index, torch.Tensor):
            index.clamp_max_(len(storage) - 1)
        else:
            index = np.clip(index, None, len(storage) - 1)
        weight = self._sum_tree[index]

        # Importance sampling weight formula:
        #   w_i = (p_i / sum(p) * N) ^ (-beta)
        #   weight_i = w_i / max(w)
        #   weight_i = (p_i / sum(p) * N) ^ (-beta) /
        #       ((min(p) / sum(p) * N) ^ (-beta))
        #   weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta)
        #   weight_i = (p_i / min(p)) ^ (-beta)
        # weight = np.power(weight / (p_min + self._eps), -self._beta)
        weight = np.power(weight / p_min, -self._beta)
        return index, {"_weight": weight}

    def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
        priority = self.default_priority

        if not (
            isinstance(priority, float)
            or len(priority) == 1
            or len(priority) == len(index)
        ):
            raise RuntimeError(
                "priority should be a scalar or an iterable of the same "
                "length as index"
            )

        self._sum_tree[index] = priority
        self._min_tree[index] = priority

    def add(self, index: int) -> None:
        super().add(index)
        self._add_or_extend(index)

    def extend(self, index: torch.Tensor) -> None:
        super().extend(index)
        self._add_or_extend(index)

    def update_priority(
        self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
    ) -> None:
        """Updates the priority of the data pointed to by the index.

        Args:
            index (int or torch.Tensor): indexes of the priorities to be
                updated.
            priority (Number or torch.Tensor): new priorities of the indexed
                elements.
        """
        if isinstance(index, INT_CLASSES):
            if not isinstance(priority, float):
                if len(priority) != 1:
                    raise RuntimeError(
                        f"priority length should be 1, got {len(priority)}"
                    )
                priority = priority.item()
        else:
            if not (
                isinstance(priority, float)
                or len(priority) == 1
                or len(index) == len(priority)
            ):
                raise RuntimeError(
                    "priority should be a number or an iterable of the same "
                    "length as index"
                )
            index = _to_numpy(index)
            priority = _to_numpy(priority)

        self._max_priority = max(self._max_priority, np.max(priority))
        priority = np.power(priority + self._eps, self._alpha)
        self._sum_tree[index] = priority
        self._min_tree[index] = priority

    def mark_update(self, index: Union[int, torch.Tensor]) -> None:
        self.update_priority(index, self.default_priority)

    def state_dict(self) -> Dict[str, Any]:
        return {
            "_alpha": self._alpha,
            "_beta": self._beta,
            "_eps": self._eps,
            "_max_priority": self._max_priority,
            "_sum_tree": deepcopy(self._sum_tree),
            "_min_tree": deepcopy(self._min_tree),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._alpha = state_dict["_alpha"]
        self._beta = state_dict["_beta"]
        self._eps = state_dict["_eps"]
        self._max_priority = state_dict["_max_priority"]
        self._sum_tree = state_dict.pop("_sum_tree")
        self._min_tree = state_dict.pop("_min_tree")
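
A minimal sketch of the priority machinery, once more with a hypothetical __len__-only storage (sample reads priorities from the segment trees, not from the storage contents). Raising the priority of two indices skews the draw towards them, and the "_weight" entry of the info dict carries the importance-sampling weights (p_i / min(p)) ^ (-beta) derived in the comment inside sample; the sketch ends with a state_dict round trip.

import numpy as np

from torchrl.data.replay_buffers.samplers import PrioritizedSampler


class _FakeStorage:
    """Hypothetical stand-in: sample() only calls len() on the storage."""

    def __len__(self):
        return 8


sampler = PrioritizedSampler(max_capacity=8, alpha=0.7, beta=0.9)
sampler.extend(np.arange(8))  # every entry starts at default_priority
sampler.update_priority(np.array([0, 1]), np.array([10.0, 10.0]))

index, info = sampler.sample(_FakeStorage(), batch_size=64)
# indices 0 and 1 now carry most of the probability mass and should dominate
print((index < 2).mean())  # typically well above the uniform 2/8
print(info["_weight"].min())  # likelier samples get smaller IS weights

# state round trip: a fresh sampler resumes from the same priorities
fresh = PrioritizedSampler(max_capacity=8, alpha=0.7, beta=0.9)
fresh.load_state_dict(sampler.state_dict())
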
