# Source code for torchrl.data.replay_buffers.samplers
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch
from torchrl._torchrl import (
    MinSegmentTreeFp32,
    MinSegmentTreeFp64,
    SumSegmentTreeFp32,
    SumSegmentTreeFp64,
)
from .storages import Storage
from .utils import _to_numpy, INT_CLASSES
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."


class Sampler(ABC):
    """A generic sampler base class for composable Replay Buffers."""

    @abstractmethod
    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        ...

    def add(self, index: int) -> None:
        return

    def extend(self, index: torch.Tensor) -> None:
        return

    def update_priority(
        self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
    ) -> None:
        return

    def mark_update(self, index: Union[int, torch.Tensor]) -> None:
        return

    @property
    def default_priority(self) -> float:
        return 1.0

    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        return

    @property
    def ran_out(self) -> bool:
        # by default, samplers never run out
        return False

    @abstractmethod
    def _empty(self):
        ...
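

# A minimal sketch of a custom sampler built on the interface above
# (illustrative, not part of the library): only ``sample`` and ``_empty``
# are abstract, every other hook has a no-op default.
class _CyclicSampler(Sampler):
    """Cycles deterministically through the storage indices, for illustration."""

    def __init__(self):
        self._cursor = 0

    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
        if len(storage) == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        # wrap around the storage, starting from the current cursor
        index = (torch.arange(batch_size) + self._cursor) % len(storage)
        self._cursor = (self._cursor + batch_size) % len(storage)
        return index, {}

    def _empty(self):
        self._cursor = 0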


class RandomSampler(Sampler):
    """A uniformly random sampler (with replacement) for composable replay buffers.

    The batch size is provided by the replay buffer when
    :meth:`~.ReplayBuffer.sample` is called.
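
    Examples:
        A minimal sketch of composing this sampler with a replay buffer
        (the storage type and stored values are illustrative):

            >>> import torch
            >>> from torchrl.data.replay_buffers import ListStorage, ReplayBuffer
            >>> rb = ReplayBuffer(storage=ListStorage(10), sampler=RandomSampler())
            >>> rb.extend([torch.tensor(i) for i in range(10)])
            >>> sample = rb.sample(4)  # 4 indices drawn uniformly, with replacement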
"""
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if len(storage) == 0:
raise RuntimeError(_EMPTY_STORAGE_ERROR)
index = torch.randint(0, len(storage), (batch_size,))
return index, {}
def _empty(self):
pass


class SamplerWithoutReplacement(Sampler):
    """A data-consuming sampler that ensures that the same sample is not present in consecutive batches.

    Args:
        drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
            If ``False``, this last sample will be kept and (unlike with torch dataloaders)
            completed with other samples from a fresh indices permutation.

    *Caution*: If the size of the storage changes between two calls, the samples will be re-shuffled
    (as we can't generally keep track of which samples have been sampled before and which haven't).
    Similarly, the storage content is expected to remain the same between two calls,
    but this is not enforced.

    When the sampler reaches the end of the list of available indices, a new sample order
    will be generated and the resulting indices will be completed with this new draw, which
    can lead to duplicated indices, unless the :obj:`drop_last` argument is set to ``True``.
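
    Examples:
        A minimal sketch (the storage type and stored values are
        illustrative):

            >>> import torch
            >>> from torchrl.data.replay_buffers import ListStorage, ReplayBuffer
            >>> sampler = SamplerWithoutReplacement(drop_last=True)
            >>> rb = ReplayBuffer(storage=ListStorage(10), sampler=sampler)
            >>> rb.extend([torch.tensor(i) for i in range(10)])
            >>> first = rb.sample(5)
            >>> second = rb.sample(5)  # disjoint from ``first``: indices are consumed
            >>> sampler.ran_out  # the permutation is now exhausted
            True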
"""
def __init__(self, drop_last: bool = False):
self._sample_list = None
self.len_storage = 0
self.drop_last = drop_last
self._ran_out = False
def _single_sample(self, len_storage, batch_size):
index = self._sample_list[:batch_size]
self._sample_list = self._sample_list[batch_size:]
# check if we have enough elements for one more batch, assuming same batch size
# will be used each time sample is called
if self._sample_list.numel() == 0 or (
self.drop_last and len(self._sample_list) < batch_size
):
self._ran_out = True
self._sample_list = torch.randperm(len_storage)
else:
self._ran_out = False
return index
    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        len_storage = len(storage)
        if len_storage == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        if self.len_storage != len_storage or self._sample_list is None:
            self._sample_list = torch.randperm(len_storage)
        if len_storage < batch_size and self.drop_last:
            raise ValueError(
                f"The batch size ({batch_size}) is greater than the storage size ({len_storage}). "
                "This makes it impossible to return a sample without repeating indices. "
                "Consider changing the sampler class or setting the 'drop_last' argument to False."
            )
        self.len_storage = len_storage
        index = self._single_sample(len_storage, batch_size)
        # We always return the indices. `drop_last` just instructs the
        # sampler to set `ran_out = True` whenever the next sample
        # would be too short. This is read by the replay buffer
        # as a signal for an early break of the __iter__().
        return index, {}
    @property
    def ran_out(self):
        return self._ran_out

    @ran_out.setter
    def ran_out(self, value):
        self._ran_out = value

    def _empty(self):
        self._sample_list = None
        self.len_storage = 0
        self._ran_out = False


class PrioritizedSampler(Sampler):
    """Prioritized sampler for replay buffers.

    Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
    Prioritized experience replay."
    (https://arxiv.org/abs/1511.05952)

    Args:
        max_capacity (int): maximum capacity of the buffer.
        alpha (float): exponent α determines how much prioritization is used,
            with α = 0 corresponding to the uniform case.
        beta (float): importance sampling negative exponent.
        eps (float, optional): delta added to the priorities to ensure that the buffer
            does not contain null priorities. Defaults to 1e-8.
        dtype (torch.dtype, optional): dtype of the segment trees storing the
            priorities. Must be ``torch.float`` or ``torch.double``.
            Defaults to ``torch.float``.
        reduction (str, optional): the reduction method for multidimensional
            tensordicts (i.e., stored trajectories). Can be one of "max", "min",
            "median" or "mean".
"""
def __init__(
self,
max_capacity: int,
alpha: float,
beta: float,
eps: float = 1e-8,
dtype: torch.dtype = torch.float,
reduction: str = "max",
) -> None:
if alpha <= 0:
raise ValueError(
f"alpha must be strictly greater than 0, got alpha={alpha}"
)
if beta < 0:
raise ValueError(f"beta must be greater or equal to 0, got beta={beta}")
self._max_capacity = max_capacity
self._alpha = alpha
self._beta = beta
self._eps = eps
self.reduction = reduction
self.dtype = dtype
self._init()
    def _init(self):
        if self.dtype in (torch.float, torch.FloatTensor, torch.float32):
            self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
            self._min_tree = MinSegmentTreeFp32(self._max_capacity)
        elif self.dtype in (torch.double, torch.DoubleTensor, torch.float64):
            self._sum_tree = SumSegmentTreeFp64(self._max_capacity)
            self._min_tree = MinSegmentTreeFp64(self._max_capacity)
        else:
            raise NotImplementedError(
                f"dtype {self.dtype} not supported by PrioritizedSampler"
            )
        self._max_priority = 1.0
    def _empty(self):
        self._init()
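
    # Note (illustrative, not from the original source): new entries receive
    # the running maximum priority, so a freshly added sample is very likely
    # to be drawn at least once before its priority is refined; e.g. with
    # _max_priority=1.0, eps=1e-8 and alpha=0.7, this returns
    # (1.0 + 1e-8) ** 0.7 ≈ 1.0.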
    @property
    def default_priority(self) -> float:
        return (self._max_priority + self._eps) ** self._alpha
    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        if len(storage) == 0:
            raise RuntimeError(_EMPTY_STORAGE_ERROR)
        p_sum = self._sum_tree.query(0, len(storage))
        p_min = self._min_tree.query(0, len(storage))
        if p_sum <= 0:
            raise RuntimeError("non-positive p_sum")
        if p_min <= 0:
            raise RuntimeError("non-positive p_min")
        mass = np.random.uniform(0.0, p_sum, size=batch_size)
        index = self._sum_tree.scan_lower_bound(mass)
        if not isinstance(index, np.ndarray):
            index = np.array([index])
        if isinstance(index, torch.Tensor):
            index.clamp_max_(len(storage) - 1)
        else:
            index = np.clip(index, None, len(storage) - 1)
        weight = self._sum_tree[index]
        # Importance sampling weight formula:
        #   w_i = (p_i / sum(p) * N) ^ (-beta)
        #   weight_i = w_i / max(w)
        # Since the exponent is negative, max(w) is attained at min(p), hence:
        #   weight_i = (p_i / sum(p) * N) ^ (-beta) /
        #       ((min(p) / sum(p) * N) ^ (-beta))
        #   weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta)
        #   weight_i = (p_i / min(p)) ^ (-beta)
        weight = np.power(weight / p_min, -self._beta)
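        # Worked example (illustrative numbers): with stored priorities
        # p = [1.0, 3.0] and beta = 1.0, an index pointing at p_i = 3.0
        # yields weight (3.0 / 1.0) ** -1.0 = 1/3, i.e. frequently sampled
        # items are down-weighted to keep the gradient estimate unbiased.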
        return index, {"_weight": weight}

    def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
        priority = self.default_priority
        if not (
            isinstance(priority, float)
            or len(priority) == 1
            or len(priority) == len(index)
        ):
            raise RuntimeError(
                "priority should be a scalar or an iterable of the same "
                "length as index"
            )
        self._sum_tree[index] = priority
        self._min_tree[index] = priority

    def add(self, index: int) -> None:
        super().add(index)
        self._add_or_extend(index)

    def extend(self, index: torch.Tensor) -> None:
        super().extend(index)
        self._add_or_extend(index)
    def update_priority(
        self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
    ) -> None:
        """Updates the priority of the data pointed to by the index.

        Args:
            index (int or torch.Tensor): indexes of the priorities to be
                updated.
            priority (Number or torch.Tensor): new priorities of the
                indexed elements.
        """
        if isinstance(index, INT_CLASSES):
            if not isinstance(priority, float):
                if len(priority) != 1:
                    raise RuntimeError(
                        f"priority length should be 1, got {len(priority)}"
                    )
                priority = priority.item()
        else:
            if not (
                isinstance(priority, float)
                or len(priority) == 1
                or len(index) == len(priority)
            ):
                raise RuntimeError(
                    "priority should be a number or an iterable of the same "
                    "length as index"
                )
            index = _to_numpy(index)
            priority = _to_numpy(priority)
        self._max_priority = max(self._max_priority, np.max(priority))
        priority = np.power(priority + self._eps, self._alpha)
        self._sum_tree[index] = priority
        self._min_tree[index] = priority
    def mark_update(self, index: Union[int, torch.Tensor]) -> None:
        self.update_priority(index, self.default_priority)

    def state_dict(self) -> Dict[str, Any]:
        return {
            "_alpha": self._alpha,
            "_beta": self._beta,
            "_eps": self._eps,
            "_max_priority": self._max_priority,
            "_sum_tree": deepcopy(self._sum_tree),
            "_min_tree": deepcopy(self._min_tree),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._alpha = state_dict["_alpha"]
        self._beta = state_dict["_beta"]
        self._eps = state_dict["_eps"]
        self._max_priority = state_dict["_max_priority"]
        self._sum_tree = state_dict.pop("_sum_tree")
        self._min_tree = state_dict.pop("_min_tree")
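

# A minimal sketch of checkpointing a prioritized sampler (illustrative,
# not part of the original source): ``state_dict`` deep-copies the segment
# trees, so the restored sampler keeps the priorities and the running
# maximum priority.
if __name__ == "__main__":
    _sampler = PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.9)
    _sampler.update_priority(0, 2.0)
    _restored = PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.9)
    _restored.load_state_dict(_sampler.state_dict())
    assert _restored._max_priority == _sampler._max_priority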