
Source code for torchrl.data.map.hash

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Callable, List

import torch
from torch.nn import Module


class BinaryToDecimal(Module):
    """A Module to convert binary-encoded tensors to decimals.

    This is a utility class that allows converting a binary encoding tensor (e.g.
    `1001`) to its decimal value (e.g. `9`).

    Args:
        num_bits (int): the number of bits to use for the bases table.
            The number of bits must be lower than or equal to the input length, and the
            input length must be divisible by ``num_bits``. If ``num_bits`` is lower
            than the number of bits in the input, the end result will be aggregated on
            the last dimension using :func:`~torch.sum`.
        device (torch.device): the device where inputs and outputs are to be expected.
        dtype (torch.dtype): the output dtype.
        convert_to_binary (bool, optional): if ``True``, the input to the ``forward``
            method will be cast to a binary input using :func:`~torch.heaviside`.
            Defaults to ``False``.

    Examples:
        >>> binary_to_decimal = BinaryToDecimal(
        ...     num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True
        ... )
        >>> binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]])
        >>> decimal = binary_to_decimal(binary)
        >>> assert decimal.shape == (2,)
        >>> assert (decimal == torch.Tensor([3, 2])).all()
    """

    def __init__(
        self,
        num_bits: int,
        device: torch.device,
        dtype: torch.dtype,
        convert_to_binary: bool = False,
    ):
        super().__init__()
        self.convert_to_binary = convert_to_binary
        # Powers of two, highest bit first, used to decode each bit group.
        self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype)
        self.num_bits = num_bits
        self.zero_tensor = torch.zeros((1,), device=device)
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        num_features = features.shape[-1]
        if self.num_bits > num_features:
            raise ValueError(f"{num_features=} is less than {self.num_bits=}")
        elif num_features % self.num_bits != 0:
            raise ValueError(f"{num_features=} is not divisible by {self.num_bits=}")

        # Optionally binarize the input, then decode each ``num_bits``-wide group
        # into a decimal value and aggregate the groups over the last dimension.
        binary_features = (
            torch.heaviside(features, self.zero_tensor)
            if self.convert_to_binary
            else features
        )
        feature_parts = binary_features.reshape(shape=(-1, self.num_bits))
        digits = torch.vmap(torch.dot, (None, 0))(
            self.bases, feature_parts.to(self.bases.dtype)
        )
        digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits))
        aggregated_digits = torch.sum(digits, dim=-1)
        return aggregated_digits
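# Usage sketch (editorial addition, not part of the original module): a minimal
# example of the aggregation behaviour when ``num_bits`` is smaller than the
# input width. Each 4-bit group is decoded separately, then the groups are
# summed over the last dimension.
# >>> converter = BinaryToDecimal(
# ...     num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=False
# ... )
# >>> bits = torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]])
# >>> # groups: 1001 -> 9 and 0001 -> 1; aggregated result: 9 + 1 = 10
# >>> assert (converter(bits) == 10).all()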
class SipHash(Module):
    """A Module to compute SipHash values for given tensors.

    A hash function module based on the SipHash implementation in python.
    Input tensors should have shape ``[batch_size, num_features]``
    and the output shape will be ``[batch_size]``.

    Args:
        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.

    .. warning:: This module relies on the builtin ``hash`` function.
        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
        variable must be set before the code is run (changing this value during code
        execution is without effect).

    Examples:
        >>> # Assuming we set PYTHONHASHSEED=0 prior to running this code
        >>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
        >>> b = a.clone()
        >>> hash_module = SipHash(as_tensor=True)
        >>> hash_a = hash_module(a)
        >>> hash_a
        tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521])
        >>> hash_b = hash_module(b)
        >>> assert (hash_a == hash_b).all()
    """

    def __init__(self, as_tensor: bool = True):
        super().__init__()
        self.as_tensor = as_tensor
    def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]:
        hash_values = []
        # numpy has no bfloat16 support, so cast to float16 before serializing.
        if x.dtype in (torch.bfloat16,):
            x = x.to(torch.float16)
        # Serialize each row of the batch into its raw bytes.
        for x_i in x.detach().cpu().numpy():
            hash_value = x_i.tobytes()
            hash_values.append(hash_value)
        if not self.as_tensor:
            return hash_values
        # Map the byte strings to integers with the builtin ``hash``.
        result = torch.tensor([hash(x) for x in hash_values], dtype=torch.int64)
        return result
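# Usage sketch (editorial addition, not part of the original module): with
# ``as_tensor=False`` the module returns the raw per-row byte strings rather
# than hashed integers, which is useful when the bytes themselves should serve
# as dictionary keys.
# >>> hash_module = SipHash(as_tensor=False)
# >>> keys = hash_module(torch.tensor([[0.1, 0.2], [0.3, 0.4]]))
# >>> assert isinstance(keys, list) and isinstance(keys[0], bytes)
# >>> storage = {key: idx for idx, key in enumerate(keys)}  # bytes as dict keys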
class RandomProjectionHash(SipHash):
    """A module that combines random projections with SipHash to get a low-dimensional
    tensor, easier to embed through :class:`~.SipHash`.

    This module requires sklearn to be installed.

    Keyword Args:
        n_components (int, optional): the low-dimensional number of components of the
            projections. Defaults to 16.
        dtype_cast (torch.dtype, optional): the dtype to cast the projection to.
            Defaults to ``torch.bfloat16``.
        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.

            .. warning:: This module relies on the builtin ``hash`` function.
                To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
                variable must be set before the code is run (changing this value during code
                execution is without effect).

        init_method (Callable, optional): the method used to initialize the projection
            matrix, called with the uninitialized buffer as its only argument.
            Defaults to :func:`torch.nn.init.normal_`.
    """

    _N_COMPONENTS_DEFAULT = 16

    def __init__(
        self,
        *,
        n_components: int | None = None,
        dtype_cast=torch.bfloat16,
        as_tensor: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None,
        **kwargs,
    ):
        if n_components is None:
            n_components = self._N_COMPONENTS_DEFAULT

        super().__init__(as_tensor=as_tensor)
        self.register_buffer("_n_components", torch.as_tensor(n_components))

        self._init = False
        if init_method is None:
            init_method = torch.nn.init.normal_
        self.init_method = init_method

        self.dtype_cast = dtype_cast
        # The projection matrix is materialized lazily in ``fit``.
        self.register_buffer("transform", torch.nn.UninitializedBuffer())

    @property
    def n_components(self):
        return self._n_components.item()
    def fit(self, x):
        """Fits the random projection to the input data."""
        # Materialize the projection matrix now that the input width is known.
        self.transform.materialize(
            (x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device
        )
        self.init_method(self.transform)
        self._init = True
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fit the projection lazily on the first call.
        if not self._init:
            self.fit(x)
        # Project to ``n_components`` dimensions, then hash with SipHash.
        x = x.to(self.dtype_cast) @ self.transform
        return super().forward(x)
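# Usage sketch (editorial addition, not part of the original module; shapes are
# illustrative): the projection matrix is materialized lazily by ``fit`` on the
# first forward call, so the module can be built before the input feature width
# is known.
# >>> proj_hash = RandomProjectionHash(n_components=8)
# >>> x = torch.randn(4, 128)  # 128-d features, too wide to hash directly
# >>> codes = proj_hash(x)     # fit() runs here; one hash per row
# >>> assert codes.shape == (4,)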
