
torchrec.sparse

Torchrec Jagged Tensors

It has three classes: JaggedTensor, KeyedJaggedTensor, and KeyedTensor.

JaggedTensor

It represents an (optionally weighted) jagged tensor. A JaggedTensor is a tensor with a jagged dimension, which is a dimension whose slices may be of different lengths. See the KeyedJaggedTensor docstring for a full example and further information.

KeyedJaggedTensor

KeyedJaggedTensor adds “key” information: it is keyed on the first dimension and jagged on the last dimension. Please refer to the KeyedJaggedTensor docstring for a full example and further information.

KeyedTensor

KeyedTensor holds a concatenated list of dense tensors, each of which can be accessed by a key. The keyed dimension can be of variable length (length_per_key). Common use cases include storage of pooled embeddings of different dimensions. Please refer to the KeyedTensor docstring for a full example and further information.

torchrec.sparse.jagged_tensor

class torchrec.sparse.jagged_tensor.ComputeJTDictToKJT(*args, **kwargs)

Bases: Module

Converts a dict of JaggedTensors to a KeyedJaggedTensor.

Example: passing in jt_dict

{
    "Feature0": JaggedTensor([[V0,V1],None,V2]),
    "Feature1": JaggedTensor([V3,V4,[V5,V6,V7]]),
}

Returns: kjt with content:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

forward(jt_dict: Dict[str, JaggedTensor]) KeyedJaggedTensor
Parameters:

jt_dict – a dict of JaggedTensor

Returns:

KeyedJaggedTensor

training: bool
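To make the conversion concrete, here is a pure-Python sketch that models each JaggedTensor as a plain (values, lengths) pair. It is not TorchRec's implementation. V0..V7 are written as 0.0..7.0, and the None slot is a length-0 slice, which matches the lengths [2, 0, 1, 1, 1, 3] used in the KeyedJaggedTensor example. The helper jt_dict_to_kjt_layout is hypothetical. It only shows how per-key data is laid out key by key in the flat values and lengths of a KeyedJaggedTensor.

```python
from typing import Dict, List, Tuple

# Each JaggedTensor is modeled here as a (values, lengths) pair of plain lists.
JT = Tuple[List[float], List[int]]


def jt_dict_to_kjt_layout(jt_dict: Dict[str, JT]) -> Tuple[List[str], List[float], List[int]]:
    """Concatenate per-key (values, lengths) into KJT-style (keys, values, lengths).

    Keys keep the dict's insertion order; values and lengths are appended
    key by key, i.e. the data is keyed on the first dimension.
    """
    keys: List[str] = []
    values: List[float] = []
    lengths: List[int] = []
    for key, (jt_values, jt_lengths) in jt_dict.items():
        keys.append(key)
        values.extend(jt_values)
        lengths.extend(jt_lengths)
    return keys, values, lengths


# Mirrors the example above, with V0..V7 replaced by 0.0..7.0.
jt_dict = {
    "Feature0": ([0.0, 1.0, 2.0], [2, 0, 1]),            # [V0,V1], None, [V2]
    "Feature1": ([3.0, 4.0, 5.0, 6.0, 7.0], [1, 1, 3]),  # [V3], [V4], [V5,V6,V7]
}
keys, values, lengths = jt_dict_to_kjt_layout(jt_dict)
print(keys)     # ['Feature0', 'Feature1']
print(lengths)  # [2, 0, 1, 1, 1, 3]
```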
class torchrec.sparse.jagged_tensor.ComputeKJTToJTDict(*args, **kwargs)

Bases: Module

Converts a KeyedJaggedTensor to a dict of JaggedTensors.

Example:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

would return

{
    "Feature0": JaggedTensor([[V0,V1],None,V2]),
    "Feature1": JaggedTensor([V3,V4,[V5,V6,V7]]),
}

forward(keyed_jagged_tensor: KeyedJaggedTensor) Dict[str, JaggedTensor]

Converts a KeyedJaggedTensor into a dict of JaggedTensors.

Parameters:

keyed_jagged_tensor (KeyedJaggedTensor) – tensor to convert

Returns:

Dict[str, JaggedTensor]

training: bool
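The following pure-Python sketch goes the other way. It splits flat values and lengths back into one (values, lengths) pair per key. It is not TorchRec code, and the helper name kjt_layout_to_jt_dict is hypothetical. The sketch assumes a fixed stride, so key i owns lengths[i * stride:(i + 1) * stride] and the matching run of values.

```python
from typing import Dict, List, Tuple


def kjt_layout_to_jt_dict(
    keys: List[str], values: List[float], lengths: List[int], stride: int
) -> Dict[str, Tuple[List[float], List[int]]]:
    """Split KJT-style flat values/lengths into one (values, lengths) per key.

    Assumes every key has the same batch size (stride).
    """
    result: Dict[str, Tuple[List[float], List[int]]] = {}
    start = 0  # running position in the flat values list
    for i, key in enumerate(keys):
        key_lengths = lengths[i * stride:(i + 1) * stride]
        num_values = sum(key_lengths)  # values owned by this key
        result[key] = (values[start:start + num_values], key_lengths)
        start += num_values
    return result


jt_dict = kjt_layout_to_jt_dict(
    keys=["Feature0", "Feature1"],
    values=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
    lengths=[2, 0, 1, 1, 1, 3],
    stride=3,
)
print(jt_dict["Feature0"])  # ([0.0, 1.0, 2.0], [2, 0, 1])
print(jt_dict["Feature1"])  # ([3.0, 4.0, 5.0, 6.0, 7.0], [1, 1, 3])
```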
class torchrec.sparse.jagged_tensor.JaggedTensor(*args, **kwargs)

Bases: Pipelineable

Represents an (optionally weighted) jagged tensor.

A JaggedTensor is a tensor with a jagged dimension, which is a dimension whose slices may be of different lengths. See KeyedJaggedTensor for a full example.

Implementation is torch.jit.script-able.

Note

We do NOT validate inputs, because validation is expensive; always pass in valid lengths, offsets, etc.

Parameters:
  • values (torch.Tensor) – values tensor in dense representation.

  • weights (Optional[torch.Tensor]) – weights for the values, if any. Tensor with the same shape as values.

  • lengths (Optional[torch.Tensor]) – jagged slices, represented as lengths.

  • offsets (Optional[torch.Tensor]) – jagged slices, represented as cumulative offsets.
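The lengths and offsets parameters describe the same slicing. offsets is the running sum of lengths with a leading 0, so B slices need B lengths and B + 1 offsets. The note above says inputs are not validated, so keeping them consistent is the caller's job. The pure-Python sketch below shows the relationship. Its helper names are hypothetical and not part of TorchRec.

```python
from itertools import accumulate
from typing import List


def lengths_to_offsets(lengths: List[int]) -> List[int]:
    # Offsets are the cumulative sum of lengths, with a leading 0.
    return [0] + list(accumulate(lengths))


def offsets_to_lengths(offsets: List[int]) -> List[int]:
    # Each length is the gap between consecutive offsets.
    return [b - a for a, b in zip(offsets, offsets[1:])]


def is_consistent(num_values: int, lengths: List[int], offsets: List[int]) -> bool:
    # Checks a caller would have to guarantee, since no validation is done.
    return (
        all(n >= 0 for n in lengths)
        and offsets == lengths_to_offsets(lengths)
        and offsets[-1] == num_values
    )


lengths = [2, 0, 1, 1, 1, 3]
offsets = lengths_to_offsets(lengths)
print(offsets)                             # [0, 2, 2, 3, 4, 5, 8]
print(offsets_to_lengths(offsets))         # [2, 0, 1, 1, 1, 3]
print(is_consistent(8, lengths, offsets))  # True
```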

static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) JaggedTensor
static from_dense(values: List[Tensor], weights: Optional[List[Tensor]] = None) JaggedTensor

Constructs a JaggedTensor from a list of B dense tensors (one per jagged slice) for the values and, optionally, the weights.

Note that the resulting lengths has shape (B,) and offsets has shape (B + 1,).

Parameters:
  • values (List[torch.Tensor]) – a list of tensors for dense representation

  • weights (Optional[List[torch.Tensor]]) – if values have weights, tensor with the same shape as values.

Returns:

JaggedTensor created from the list of dense tensors.

Return type:

JaggedTensor

Example:

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
weights = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
j1 = JaggedTensor.from_dense(
    values=values,
    weights=weights,
)

# j1 = [[1.0], [], [7.0, 8.0], [10.0, 11.0, 12.0]]
static from_dense_lengths(values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None) JaggedTensor

Constructs JaggedTensor from dense values/weights of shape (B, N,).

Note that lengths is still of shape (B,).
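The sketch below reads the description above as follows: row i of the (B, N) values keeps its first lengths[i] entries, and the rest is padding. That reading is an assumption about the semantics, not a quote of TorchRec's implementation. The pure-Python sketch models from_dense_lengths with nested lists, and the helper dense_lengths_to_jagged is hypothetical.

```python
from typing import List, Tuple


def dense_lengths_to_jagged(
    values: List[List[float]], lengths: List[int]
) -> Tuple[List[float], List[int]]:
    """Flatten a padded (B, N) matrix into jagged (values, lengths).

    Row i contributes its first lengths[i] entries; the remaining
    entries of that row are treated as padding and dropped.
    """
    assert len(values) == len(lengths), "lengths must have shape (B,)"
    flat: List[float] = []
    for row, n in zip(values, lengths):
        assert 0 <= n <= len(row), "each length must fit in its row"
        flat.extend(row[:n])
    return flat, lengths


padded = [
    [1.0, 2.0, 0.0],
    [0.0, 0.0, 0.0],
    [3.0, 4.0, 5.0],
]
flat, lengths = dense_lengths_to_jagged(padded, [2, 0, 3])
print(flat)  # [1.0, 2.0, 3.0, 4.0, 5.0]
```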

lengths() Tensor
lengths_or_none() Optional[Tensor]
offsets() Tensor
offsets_or_none() Optional[Tensor]
record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

to(device: device, non_blocking: bool = False) JaggedTensor

Please be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to() might return self or a copy of self. So always assign the result of to(), for example jt = jt.to(new_device).

to_dense() List[Tensor]

Constructs a dense-representation of the JT’s values.

Returns:

list of tensors.

Return type:

List[torch.Tensor]

Example:

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

values_list = jt.to_dense()

# values_list = [
#     torch.tensor([1.0, 2.0]),
#     torch.tensor([]),
#     torch.tensor([3.0]),
#     torch.tensor([4.0]),
#     torch.tensor([5.0]),
#     torch.tensor([6.0, 7.0, 8.0]),
# ]
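The result above follows directly from the offsets: slice i runs from offsets[i] up to offsets[i + 1]. This pure-Python sketch reproduces the example on plain lists. split_by_offsets is a hypothetical helper that illustrates the indexing and is not TorchRec's tensor code.

```python
from typing import List


def split_by_offsets(values: List[float], offsets: List[int]) -> List[List[float]]:
    # Slice i spans values[offsets[i]:offsets[i + 1]].
    return [values[start:end] for start, end in zip(offsets, offsets[1:])]


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
offsets = [0, 2, 2, 3, 4, 5, 8]
print(split_by_offsets(values, offsets))
# [[1.0, 2.0], [], [3.0], [4.0], [5.0], [6.0, 7.0, 8.0]]
```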
to_dense_weights() Optional[List[Tensor]]

Constructs a dense-representation of the JT’s weights.

Returns:

list of tensors, None if no weights.

Return type:

Optional[List[torch.Tensor]]

Example:

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

weights_list = jt.to_dense_weights()

# weights_list = [
#     torch.tensor([0.1, 0.2]),
#     torch.tensor([]),
#     torch.tensor([0.3]),
#     torch.tensor([0.4]),
#     torch.tensor([0.5]),
#     torch.tensor([0.6, 0.7, 0.8]),
# ]
to_padded_dense(desired_length: Optional[int] = None, padding_value: float = 0.0) Tensor

Constructs a 2D dense tensor from the JT’s values of shape (B, N,).

Note that B is the length of self.lengths() and N is the longest feature length or desired_length.

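If desired_length is greater than a slice's length, the slice is padded with padding_value; otherwise the slice is truncated to its first desired_length values.

Parameters:
  • desired_length (Optional[int]) – the size N of the second dimension; defaults to the longest slice length.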

  • padding_value (float) – padding value if we need to pad.

Returns:

2D dense tensor.

Return type:

torch.Tensor

Example:

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

dt = jt.to_padded_dense(
    desired_length=2,
    padding_value=10.0,
)

# dt = [
#     [1.0, 2.0],
#     [10.0, 10.0],
#     [3.0, 10.0],
#     [4.0, 10.0],
#     [5.0, 10.0],
#     [6.0, 7.0],
# ]
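The padding and truncation rule can be sketched on plain lists. This is a pure-Python model of the behavior described above, not TorchRec's implementation, and pad_or_truncate is a hypothetical name. When desired_length is omitted, N falls back to the longest slice.

```python
from typing import List, Optional


def pad_or_truncate(
    slices: List[List[float]],
    desired_length: Optional[int] = None,
    padding_value: float = 0.0,
) -> List[List[float]]:
    # N is desired_length if given, otherwise the longest slice length.
    if desired_length is not None:
        n = desired_length
    else:
        n = max((len(s) for s in slices), default=0)
    padded: List[List[float]] = []
    for s in slices:
        kept = s[:n]  # truncate long slices to their first n values
        padded.append(kept + [padding_value] * (n - len(kept)))  # pad short ones
    return padded


slices = [[1.0, 2.0], [], [3.0], [4.0], [5.0], [6.0, 7.0, 8.0]]
print(pad_or_truncate(slices, desired_length=2, padding_value=10.0))
# [[1.0, 2.0], [10.0, 10.0], [3.0, 10.0], [4.0, 10.0], [5.0, 10.0], [6.0, 7.0]]
```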
to_padded_dense_weights(desired_length: Optional[int] = None, padding_value: float = 0.0) Optional[Tensor]

Constructs a 2D dense tensor from the JT’s weights of shape (B, N,).

Note that B is the length of self.lengths() and N is the longest feature length or desired_length.

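If desired_length is greater than a slice's length, the slice is padded with padding_value; otherwise the slice is truncated to its first desired_length values.

Parameters:
  • desired_length (Optional[int]) – the size N of the second dimension; defaults to the longest slice length.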

  • padding_value (float) – padding value if we need to pad.

Returns:

2D dense tensor, None if no weights.

Return type:

Optional[torch.Tensor]

Example:

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

d_wt = jt.to_padded_dense_weights(
    desired_length=2,
    padding_value=1.0,
)

# d_wt = [
#     [0.1, 0.2],
#     [1.0, 1.0],
#     [0.3, 1.0],
#     [0.4, 1.0],
#     [0.5, 1.0],
#     [0.6, 0.7],
# ]
values() Tensor
weights() Tensor
weights_or_none() Optional[Tensor]
class torchrec.sparse.jagged_tensor.JaggedTensorMeta(name, bases, namespace, **kwargs)

Bases: ABCMeta, ProxyableClassMeta

class torchrec.sparse.jagged_tensor.KeyedJaggedTensor(*args, **kwargs)

Bases: Pipelineable

Represents an (optionally weighted) keyed jagged tensor.

A KeyedJaggedTensor is a tensor with a jagged dimension, which is a dimension whose slices may be of different lengths. It is keyed on the first dimension and jagged on the last dimension.

Implementation is torch.jit.script-able.

Parameters:
  • keys (List[str]) – keys to the jagged Tensor.

  • values (torch.Tensor) – values tensor in dense representation.

  • weights (Optional[torch.Tensor]) – if the values have weights. Tensor with the same shape as values.

  • lengths (Optional[torch.Tensor]) – jagged slices, represented as lengths.

  • offsets (Optional[torch.Tensor]) – jagged slices, represented as cumulative offsets.

  • stride (Optional[int]) – number of examples per batch.

  • stride_per_key_per_rank (Optional[List[List[int]]]) – batch size (number of examples) per key per rank. The outer list is indexed by key and each inner list by rank: in a distributed context, the value at index r of a key's inner list is the number of examples in the batch for that key from rank r.

  • length_per_key (Optional[List[int]]) – number of values for each key, i.e. the sum of that key's lengths.

  • offset_per_key (Optional[List[int]]) – start offset for each key and final offset.

  • index_per_key (Optional[Dict[str, int]]) – index for each key.

  • jt_dict (Optional[Dict[str, JaggedTensor]]) –

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to expand deduplicated embedding output for variable stride per key.
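Here is a worked example of stride_per_key_per_rank, using made-up numbers for two keys gathered from two ranks. Summing each inner list gives that key's total batch size. Each example of each key contributes one entry to lengths, so the lengths tensor needs as many entries as the sum of those batch sizes. This is plain arithmetic on the layout, not a TorchRec API call.

```python
from typing import List

# Hypothetical input: outer list = keys, inner list = examples from rank 0 and rank 1.
stride_per_key_per_rank: List[List[int]] = [
    [2, 1],  # "Feature0": 2 examples from rank 0, 1 from rank 1
    [1, 3],  # "Feature1": 1 example from rank 0, 3 from rank 1
]

# Total batch size for each key across all ranks.
batch_size_per_key = [sum(per_rank) for per_rank in stride_per_key_per_rank]
print(batch_size_per_key)  # [3, 4]

# One lengths entry per example per key, so lengths must have this many entries.
num_lengths = sum(batch_size_per_key)
print(num_lengths)  # 7
```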

Example:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

dim_0: keyed dimension (i.e. `Feature0`, `Feature1`)
dim_1: optional second dimension (i.e. batch size)
dim_2: the jagged dimension, whose slice lengths range from 0 to 3 in the above example

# We represent this data with following inputs:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7]  # V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7]  # W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3]  # representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8]  # offsets from 0 for each jagged slice
keys: List[str] = ["Feature0", "Feature1"]  # correspond to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}  # index for each key
offset_per_key: List[int] = [0, 3, 8]  # start offset for each key and final offset
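The derived fields above can be computed from keys, lengths, and the stride, which is 3 here. The pure-Python sketch below shows that arithmetic using the example's lengths. It assumes a fixed stride for every key and does not call TorchRec.

```python
from itertools import accumulate
from typing import Dict, List

keys: List[str] = ["Feature0", "Feature1"]
lengths: List[int] = [2, 0, 1, 1, 1, 3]
stride = 3  # examples per batch (dim_1 in the example above)

# offsets: running sum of lengths with a leading 0.
offsets = [0] + list(accumulate(lengths))

# length_per_key: total number of values owned by each key.
length_per_key = [sum(lengths[i * stride:(i + 1) * stride]) for i in range(len(keys))]

# offset_per_key: start offset of each key plus the final offset.
offset_per_key = [0] + list(accumulate(length_per_key))

# index_per_key: position of each key in keys.
index_per_key: Dict[str, int] = {key: i for i, key in enumerate(keys)}

print(offsets)         # [0, 2, 2, 3, 4, 5, 8]
print(length_per_key)  # [3, 5]
print(offset_per_key)  # [0, 3, 8]
print(index_per_key)   # {'Feature0': 0, 'Feature1': 1}
```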
static concat(kjt_list: List[KeyedJaggedTensor]) KeyedJaggedTensor
device() device
static dist_init(keys: List[str], tensors: List[Tensor], variable_stride_per_key: bool, num_workers: int, recat: Optional[Tensor], stride_per_rank: Optional[List[int]], stagger: int = 1) KeyedJaggedTensor
dist_labels() List[str]
dist_splits(key_splits: List[int]) List[List[int]]
dist_tensors() List[Tensor]
static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) KeyedJaggedTensor
static empty_like(kjt: KeyedJaggedTensor) KeyedJaggedTensor
flatten_lengths() KeyedJaggedTensor
static from_jt_dict(jt_dict: Dict[str, JaggedTensor]) KeyedJaggedTensor

Constructs a KeyedJaggedTensor from a Dict[str, JaggedTensor]. This function works ONLY if all the JaggedTensors have the same “implicit” batch_size dimension.

Conceptually, each JaggedTensor can be viewed as a 2-D tensor of shape [batch_size x variable_feature_dim]. If some batch element has no value for a feature, the input JaggedTensor could simply omit it.

A KeyedJaggedTensor, however, typically pads with “None” (an empty slice) so that all the JaggedTensors it stores share the same batch_size dimension. If the input JaggedTensor did not pad the empty batch elements itself, this function will error or not work.

Consider the visualization of the following KeyedJaggedTensor:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

Notice that the inputs for this KeyedJaggedTensor would have looked like:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7]  # V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7]  # W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3]  # representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8]  # offsets from 0 for each jagged slice
keys: List[str] = ["Feature0", "Feature1"]  # correspond to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}  # index for each key
offset_per_key: List[int] = [0, 3, 8]  # start offset for each key and final offset

Now suppose the input is

jt_dict = {
    # "Feature0"   [V0,V1] [V2]
    # "Feature1"   [V3]    [V4]    [V5,V6,V7]
}

with the “None” left out of the JaggedTensor. This function would then fail, because it cannot tell which batch position the missing “None” belongs to, so it cannot pad it correctly.

In that case the lengths tensor inferred by this function would be [2, 1, 1, 1, 3], implying a variable batch_size along dim_1. This violates the precondition that a KeyedJaggedTensor has a fixed batch_size dimension.
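The precondition can be expressed as a check that every key has the same number of length entries. The helper infer_stride below is hypothetical and not part of TorchRec. It models each JaggedTensor only by its lengths, and it shows why keeping “None” as a length-0 slice matters.

```python
from typing import Dict, List


def infer_stride(lengths_per_key: Dict[str, List[int]]) -> int:
    """Return the shared batch size, or raise if the keys disagree."""
    batch_sizes = {key: len(lengths) for key, lengths in lengths_per_key.items()}
    if len(set(batch_sizes.values())) > 1:
        raise ValueError(f"keys have different batch sizes: {batch_sizes}")
    return next(iter(batch_sizes.values()), 0)


# "None" kept as a length-0 slice: both keys have batch size 3.
print(infer_stride({"Feature0": [2, 0, 1], "Feature1": [1, 1, 3]}))  # 3

# "None" dropped: Feature0 now has batch size 2, so no single stride exists.
try:
    infer_stride({"Feature0": [2, 1], "Feature1": [1, 1, 3]})
except ValueError as err:
    print(err)
```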

static from_lengths_sync(keys: List[str], values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor
static from_offsets_sync(keys: List[str], values: Tensor, offsets: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor
index_per_key() Dict[str, int]
inverse_indices() Tuple[List[str], Tensor]
inverse_indices_or_none() Optional[Tuple[List[str], Tensor]]
keys() List[str]
length_per_key() List[int]
length_per_key_or_none() Optional[List[int]]
lengths() Tensor
lengths_offset_per_key() List[int]
lengths_or_none() Optional[Tensor]
offset_per_key() List[int]
offset_per_key_or_none() Optional[List[int]]
offsets() Tensor
offsets_or_none() Optional[Tensor]
permute(indices: List[int], indices_tensor: Optional[Tensor] = None, include_inverse_indices: bool = False) KeyedJaggedTensor
pin_memory() KeyedJaggedTensor
record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

split(segments: List[int]) List[KeyedJaggedTensor]
stride() int
stride_per_key() List[int]
stride_per_key_per_rank() List[List[int]]
sync() KeyedJaggedTensor
to(device: device, non_blocking: bool = False, dtype: Optional[dtype] = None) KeyedJaggedTensor

Please be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to() might return self or a copy of self. So always assign the result of to(), for example kjt = kjt.to(new_device).

to_dict() Dict[str, JaggedTensor]
unsync() KeyedJaggedTensor
values() Tensor
variable_stride_per_key() bool
weights() Tensor
weights_or_none() Optional[Tensor]
class torchrec.sparse.jagged_tensor.KeyedTensor(*args, **kwargs)

Bases: Pipelineable

KeyedTensor holds a concatenated list of dense tensors, each of which can be accessed by a key.

The keyed dimension can be of variable length (length_per_key). Common use cases include storage of pooled embeddings of different dimensions.

Implementation is torch.jit.script-able.

Parameters:
  • keys (List[str]) – list of keys.

  • length_per_key (List[int]) – length of each key along key dimension.

  • values (torch.Tensor) – dense tensor, concatenated typically along key dimension.

  • key_dim (int) – key dimension, zero-indexed; defaults to 1 (dimension 0 is typically the batch dimension B).

Example:

# kt is KeyedTensor holding

#                         0           1           2
#     "Embedding A"    [1,1]       [1,1]        [1,1]
#     "Embedding B"    [2,1,2]     [2,1,2]      [2,1,2]
#     "Embedding C"    [3,1,2,3]   [3,1,2,3]    [3,1,2,3]

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

tensor_list = [
    torch.tensor([[1,1]] * 3),
    torch.tensor([[2,1,2]] * 3),
    torch.tensor([[3,1,2,3]] * 3),
]

keys = ["Embedding A", "Embedding B", "Embedding C"]

kt = KeyedTensor.from_tensor_list(keys, tensor_list)

kt.values()
    # tensor(
    #     [
    #         [1, 1, 2, 1, 2, 3, 1, 2, 3],
    #         [1, 1, 2, 1, 2, 3, 1, 2, 3],
    #         [1, 1, 2, 1, 2, 3, 1, 2, 3],
    #     ]
    # )

kt["Embedding B"]
    # tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
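This sketch assumes that offset_per_key() is the running sum of length_per_key() with a leading 0, which is the same “start offset for each key and final offset” convention described for KeyedJaggedTensor. Under that assumption, looking up a key reduces to slicing along key_dim. The pure-Python sketch slices a single row of the example above. It illustrates the indexing only and is not TorchRec's implementation.

```python
from itertools import accumulate
from typing import Dict, List

keys = ["Embedding A", "Embedding B", "Embedding C"]
length_per_key = [2, 3, 4]
# One row of the concatenated values tensor from the example above.
row = [1, 1, 2, 1, 2, 3, 1, 2, 3]

# offset_per_key: where each key's slice starts, plus the total width.
offset_per_key = [0] + list(accumulate(length_per_key))
print(offset_per_key)  # [0, 2, 5, 9]

# Slicing the row between consecutive offsets recovers each key's embedding.
per_key: Dict[str, List[int]] = {
    key: row[offset_per_key[i]:offset_per_key[i + 1]] for i, key in enumerate(keys)
}
print(per_key["Embedding B"])  # [2, 1, 2]
```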
static from_tensor_list(keys: List[str], tensors: List[Tensor], key_dim: int = 1, cat_dim: int = 1) KeyedTensor
key_dim() int
keys() List[str]
length_per_key() List[int]
offset_per_key() List[int]
record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

static regroup(keyed_tensors: List[KeyedTensor], groups: List[List[str]]) List[Tensor]
static regroup_as_dict(keyed_tensors: List[KeyedTensor], groups: List[List[str]], keys: List[str]) Dict[str, Tensor]
to(device: device, non_blocking: bool = False) KeyedTensor

Please be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to() might return self or a copy of self. So always assign the result of to(), for example kt = kt.to(new_device).

to_dict() Dict[str, Tensor]
values() Tensor
torchrec.sparse.jagged_tensor.flatten_kjt_list(kjt_arr: List[KeyedJaggedTensor]) Tuple[List[Optional[Tensor]], List[List[str]]]
torchrec.sparse.jagged_tensor.is_non_strict_exporting() bool
torchrec.sparse.jagged_tensor.jt_is_equal(jt_1: JaggedTensor, jt_2: JaggedTensor) bool

This function checks if two JaggedTensors are equal by comparing their internal representations. The comparison is done by comparing the values of the internal representations themselves. For optional fields, None values are treated as equal.

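Parameters:
  • jt_1 (JaggedTensor) – the first JaggedTensor to compare.

  • jt_2 (JaggedTensor) – the second JaggedTensor to compare.

Returns: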

True if both JaggedTensors have the same values

Return type:

bool
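The rule for optional fields can be sketched in pure Python. The tuple records and optional_equal are hypothetical stand-ins for a JaggedTensor's internal fields. The sketch also assumes that a field which is None on one side but present on the other counts as unequal.

```python
from typing import List, Optional


def optional_equal(a: Optional[List[float]], b: Optional[List[float]]) -> bool:
    # Two missing fields compare equal; a missing field never equals a present one.
    if a is None or b is None:
        return a is None and b is None
    return a == b


# Hypothetical JaggedTensor-like records: (values, weights, lengths).
jt_1 = ([1.0, 2.0], None, [1, 1])
jt_2 = ([1.0, 2.0], None, [1, 1])
jt_3 = ([1.0, 2.0], [0.5, 0.5], [1, 1])

print(all(optional_equal(x, y) for x, y in zip(jt_1, jt_2)))  # True
print(all(optional_equal(x, y) for x, y in zip(jt_1, jt_3)))  # False
```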

torchrec.sparse.jagged_tensor.kjt_is_equal(kjt_1: KeyedJaggedTensor, kjt_2: KeyedJaggedTensor) bool

This function checks if two KeyedJaggedTensors are equal by comparing their internal representations. The comparison is done by comparing the values of the internal representations themselves. For optional fields, None values are treated as equal. Keys are compared by checking that both key lists have the same length and contain the same keys in the same order.

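Parameters:
  • kjt_1 (KeyedJaggedTensor) – the first KeyedJaggedTensor to compare.

  • kjt_2 (KeyedJaggedTensor) – the second KeyedJaggedTensor to compare.

Returns: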

True if both KeyedJaggedTensors have the same values

Return type:

bool

torchrec.sparse.jagged_tensor.unflatten_kjt_list(values: List[Optional[Tensor]], contexts: List[List[str]]) List[KeyedJaggedTensor]
