
Pooled Embedding Operators

Stable API

torch.ops.fbgemm.merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) → Tensor

Concatenate embedding outputs from different devices (on the same host) onto the target device.

Parameters:
  • pooled_embeddings (List[Tensor]) – A list of embedding outputs from different devices on the same host. Each output has 2 dimensions.

  • uncat_dim_size (int) – The size of the dimension that is not concatenated, i.e., if cat_dim=0, uncat_dim_size is the size of dim 1 and vice versa.

  • target_device (torch.device) – The target device that aggregates all the embedding outputs.

  • cat_dim (int = 1) – The dimension along which the tensors are concatenated

Returns:

The concatenated embedding output (2D) on the target device
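Example (a minimal sketch; it assumes a host with at least two CUDA devices):

>>> import torch
>>>
>>> # Two 2D pooled embedding outputs with the same batch size (2),
>>> # located on different GPUs of the same host
>>> emb0 = torch.randn(2, 4, device="cuda:0")
>>> emb1 = torch.randn(2, 8, device="cuda:1")
>>>
>>> # Concatenate along dim 1 (the default cat_dim); uncat_dim_size is the
>>> # size of the un-concatenated dim 0, i.e., the batch size
>>> merged = torch.ops.fbgemm.merge_pooled_embeddings([emb0, emb1], 2, torch.device("cuda:0"), 1)
>>> print(merged.shape)
torch.Size([2, 12])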

torch.ops.fbgemm.permute_pooled_embs(pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list) → Tensor

Permute embedding outputs along the feature dimension.

The embedding output tensor pooled_embs contains the embedding outputs for all features in a batch. It is represented in a 2D format, where the rows correspond to the batch dimension and the columns to the features' embedding dimensions laid out back to back (feature × embedding dimension). Permuting along the feature dimension is therefore a permutation of column blocks along the second dimension (dim 1).

Parameters:
  • pooled_embs (Tensor) – The embedding outputs to permute. Shape is (B_local, total_global_D), where B_local is the local batch size and total_global_D is the total embedding dimension across all features (global)

  • offset_dim_list (Tensor) – The complete cumulative sum of embedding dimensions of all features. Shape is T + 1 where T is the total number of features

  • permute_list (Tensor) – A tensor that describes how each feature is permuted. permute_list[i] indicates that the feature permute_list[i] is permuted to position i

  • inv_offset_dim_list (Tensor) – The complete cumulative sum of inverse embedding dimensions, which are the permuted embedding dimensions. inv_offset_dim_list[i] represents the starting embedding position of feature permute_list[i]

  • inv_permute_list (Tensor) – The inverse permute list, which contains the permuted positions of each feature. inv_permute_list[i] represents the permuted position of feature i

Returns:

Permuted embedding outputs (Tensor). Same shape as pooled_embs
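Conceptually, the forward permutation slices out each feature's column block and re-concatenates the blocks in permuted order along dim 1. The following pure-PyTorch helper is an illustrative sketch of that semantics (permute_pooled_reference is a hypothetical name, not part of FBGEMM):

>>> import torch
>>>
>>> def permute_pooled_reference(pooled, offsets, perm):
...     # Output column block i is the column block of input feature perm[i],
...     # i.e., input columns offsets[perm[i]] to offsets[perm[i] + 1]
...     return torch.cat([pooled[:, offsets[f]:offsets[f + 1]] for f in perm], dim=1)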

Example:

>>> import torch
>>> from itertools import accumulate
>>>
>>> # Suppose batch size = 3 and there are 3 features
>>> batch_size = 3
>>>
>>> # Embedding dimensions for each feature
>>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
>>>
>>> # Permute list, i.e., move feature 2 to position 0, move feature 0
>>> # to position 1, and so on
>>> permute = torch.tensor([2, 0, 1], dtype=torch.int64, device="cuda")
>>>
>>> # Compute embedding dim offsets
>>> offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims.tolist())), dtype=torch.int64, device="cuda")
>>> print(offset_dim_list)
tensor([ 0,  4,  8, 16], device='cuda:0')
>>>
>>> # Compute inverse embedding dims
>>> inv_embs_dims = [embs_dims[p].item() for p in permute]
>>> # Compute complete cumulative sum of inverse embedding dims
>>> inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device="cuda")
>>> print(inv_offset_dim_list)
tensor([ 0,  8, 12, 16], device='cuda:0')
>>>
>>> # Compute inverse permutes
>>> inv_permute = [0] * len(permute)
>>> for i, p in enumerate(permute.tolist()):
...     inv_permute[p] = i
>>> inv_permute_list = torch.tensor(inv_permute, dtype=torch.int64, device="cuda")
>>> print(inv_permute_list)
tensor([1, 2, 0], device='cuda:0')
>>>
>>> # Generate an example input
>>> pooled_embs = torch.arange(embs_dims.sum().item() * batch_size, dtype=torch.float32, device="cuda").reshape(batch_size, -1)
>>> print(pooled_embs)
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15.],
        [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
         30., 31.],
        [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
         46., 47.]], device='cuda:0')
>>>
>>> torch.ops.fbgemm.permute_pooled_embs(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list)
tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
          6.,  7.],
        [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
         22., 23.],
        [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
         38., 39.]], device='cuda:0')
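
As a sanity check, the operator's output matches the pure-PyTorch sketch from above (reusing the hypothetical permute_pooled_reference helper):

>>> out = torch.ops.fbgemm.permute_pooled_embs(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list)
>>> print(torch.equal(out, permute_pooled_reference(pooled_embs, offset_dim_list.tolist(), permute.tolist())))
True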

