
Data Types

TorchRec contains data types for representing embedding inputs, otherwise known as sparse features. Sparse features are typically indices that are meant to be fed into embedding tables. For a given batch, the number of embedding lookup indices is variable, so a jagged dimension is needed to represent the variable number of lookup indices per example.

This section covers the classes for the three TorchRec data types for representing sparse features: JaggedTensor, KeyedJaggedTensor, and KeyedTensor.
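
For example, a batch of three samples with 2, 0, and 1 lookup indices can be represented with a flat values tensor plus a per-sample lengths tensor. A minimal sketch (the index values are illustrative):

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

# three samples with 2, 0, and 1 embedding lookup indices, respectively
jt = JaggedTensor(
    values=torch.tensor([10, 42, 7]),   # all lookup indices, flattened
    lengths=torch.tensor([2, 0, 1]),    # number of indices per sample
)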

class torchrec.sparse.jagged_tensor.JaggedTensor(*args, **kwargs)

Represents an (optionally weighted) jagged tensor.

A JaggedTensor is a tensor with a jagged dimension, which is a dimension whose slices may be of different lengths. See KeyedJaggedTensor for a full example.

Implementation is torch.jit.script-able.

Note

We do NOT perform input validation, as it is expensive; you should always pass in valid lengths, offsets, etc.

Parameters:
  • values (torch.Tensor) – values tensor in dense representation.

  • weights (Optional[torch.Tensor]) – weights for the values, if any; a tensor with the same shape as values.

  • lengths (Optional[torch.Tensor]) – jagged slices, represented as lengths.

  • offsets (Optional[torch.Tensor]) – jagged slices, represented as cumulative offsets.

device() device

Get JaggedTensor device.

Returns:

the device of the values tensor.

Return type:

torch.device

static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) JaggedTensor

Constructs an empty JaggedTensor.

Parameters:
  • is_weighted (bool) – whether the JaggedTensor has weights.

  • device (Optional[torch.device]) – device for JaggedTensor.

  • values_dtype (Optional[torch.dtype]) – dtype for values.

  • weights_dtype (Optional[torch.dtype]) – dtype for weights.

  • lengths_dtype (torch.dtype) – dtype for lengths.

Returns:

empty JaggedTensor.

Return type:

JaggedTensor
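
A small usage sketch:

jt = JaggedTensor.empty(values_dtype=torch.float32, lengths_dtype=torch.int64)

# expected: jt.values().numel() == 0 and jt.lengths().numel() == 0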

static from_dense(values: List[Tensor], weights: Optional[List[Tensor]] = None) JaggedTensor

Constructs a JaggedTensor from a list of tensors as values, with optional weights. lengths will be computed, of shape (B,), where B is len(values), i.e. the batch size.

Parameters:
  • values (List[torch.Tensor]) – a list of tensors for dense representation

  • weights (Optional[List[torch.Tensor]]) – if values have weights, tensor with the same shape as values.

Returns:

JaggedTensor created from the list of dense tensors.

Return type:

JaggedTensor

Example:

values = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
weights = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
j1 = JaggedTensor.from_dense(
    values=values,
    weights=weights,
)

# j1 = [[1.0], [], [7.0, 8.0], [10.0, 11.0, 12.0]]

static from_dense_lengths(values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None) JaggedTensor

Constructs JaggedTensor from values and lengths tensors, with optional weights. Note that lengths is still of shape (B,), where B is the batch size.

Parameters:
  • values (torch.Tensor) – dense representation of values.

  • lengths (torch.Tensor) – jagged slices, represented as lengths.

  • weights (Optional[torch.Tensor]) – if values have weights, tensor with the same shape as values.

Returns:

JaggedTensor created from 2D dense tensor.

Return type:

JaggedTensor
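
An illustrative sketch (not from the original docs); it assumes each row of the dense values tensor contributes its first lengths[i] entries:

values = torch.Tensor([
    [1.0, 0.0, 0.0],
    [7.0, 8.0, 0.0],
    [10.0, 11.0, 12.0],
])
lengths = torch.tensor([1, 2, 3])

jt = JaggedTensor.from_dense_lengths(values=values, lengths=lengths)

# expected: jt corresponds to [[1.0], [7.0, 8.0], [10.0, 11.0, 12.0]]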

lengths() Tensor

Get JaggedTensor lengths. If not computed, compute it from offsets.

Returns:

the lengths tensor.

Return type:

torch.Tensor

lengths_or_none() Optional[Tensor]

Get JaggedTensor lengths. If not computed, return None.

Returns:

the lengths tensor.

Return type:

Optional[torch.Tensor]

offsets() Tensor

Get JaggedTensor offsets. If not computed, compute it from lengths.

Returns:

the offsets tensor.

Return type:

torch.Tensor

offsets_or_none() Optional[Tensor]

Get JaggedTensor offsets. If not computed, return None.

Returns:

the offsets tensor.

Return type:

Optional[torch.Tensor]
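
lengths and offsets are two encodings of the same jagged structure: offsets is the cumulative sum of lengths with a leading zero, and whichever one was not supplied is computed lazily. A small sketch:

jt = JaggedTensor(
    values=torch.Tensor([1.0, 2.0, 3.0]),
    lengths=torch.tensor([2, 0, 1]),
)

jt.offsets_or_none()  # None – offsets were never supplied or computed
jt.offsets()          # tensor([0, 2, 2, 3]), computed from lengths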

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

to(device: device, non_blocking: bool = False) JaggedTensor

Move the JaggedTensor to the specified device.

Parameters:
  • device (torch.device) – the device to move to.

  • non_blocking (bool) – whether to perform the copy asynchronously.

Returns:

the moved JaggedTensor.

Return type:

JaggedTensor
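
A small usage sketch (the CUDA check only keeps the example safe on CPU-only machines):

jt = JaggedTensor(
    values=torch.Tensor([1.0, 2.0, 3.0]),
    lengths=torch.tensor([2, 1]),
)

if torch.cuda.is_available():
    jt = jt.to(torch.device("cuda"), non_blocking=True)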

to_dense() List[Tensor]

Constructs a dense representation of the JT’s values.

Returns:

list of tensors.

Return type:

List[torch.Tensor]

Example:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

values_list = jt.to_dense()

# values_list = [
#     torch.tensor([1.0, 2.0]),
#     torch.tensor([]),
#     torch.tensor([3.0]),
#     torch.tensor([4.0]),
#     torch.tensor([5.0]),
#     torch.tensor([6.0, 7.0, 8.0]),
# ]

to_dense_weights() Optional[List[Tensor]]

Constructs a dense representation of the JT’s weights.

Returns:

list of tensors, None if no weights.

Return type:

Optional[List[torch.Tensor]]

Example:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

weights_list = jt.to_dense_weights()

# weights_list = [
#     torch.tensor([0.1, 0.2]),
#     torch.tensor([]),
#     torch.tensor([0.3]),
#     torch.tensor([0.4]),
#     torch.tensor([0.5]),
#     torch.tensor([0.6, 0.7, 0.8]),
# ]

to_padded_dense(desired_length: Optional[int] = None, padding_value: float = 0.0) Tensor

Constructs a 2D dense tensor of shape (B, N) from the JT’s values.

Note that B is the length of self.lengths() and N is the longest feature length or desired_length.

If desired_length > length, we will pad with padding_value; otherwise, the row is truncated to the first desired_length values.

Parameters:
  • desired_length (int) – the length of the tensor.

  • padding_value (float) – padding value if we need to pad.

Returns:

2d dense tensor.

Return type:

torch.Tensor

Example:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

dt = jt.to_padded_dense(
    desired_length=2,
    padding_value=10.0,
)

# dt = [
#     [1.0, 2.0],
#     [10.0, 10.0],
#     [3.0, 10.0],
#     [4.0, 10.0],
#     [5.0, 10.0],
#     [6.0, 7.0],
# ]

to_padded_dense_weights(desired_length: Optional[int] = None, padding_value: float = 0.0) Optional[Tensor]

Constructs a 2D dense tensor of shape (B, N) from the JT’s weights.

Note that B (batch size) is the length of self.lengths() and N is the longest feature length or desired_length.

If desired_length > length, we will pad with padding_value; otherwise, the row is truncated to the first desired_length weights.

Like to_padded_dense but for the JT’s weights instead of values.

Parameters:
  • desired_length (int) – the length of the tensor.

  • padding_value (float) – padding value if we need to pad.

Returns:

2d dense tensor, None if no weights.

Return type:

Optional[torch.Tensor]

Example:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

d_wt = jt.to_padded_dense_weights(
    desired_length=2,
    padding_value=1.0,
)

# d_wt = [
#     [0.1, 0.2],
#     [1.0, 1.0],
#     [0.3, 1.0],
#     [0.4, 1.0],
#     [0.5, 1.0],
#     [0.6, 0.7],
# ]

values() Tensor

Get JaggedTensor values.

Returns:

the values tensor.

Return type:

torch.Tensor

weights() Tensor

Get JaggedTensor weights. If None, throw an error.

Returns:

the weights tensor.

Return type:

torch.Tensor

weights_or_none() Optional[Tensor]

Get JaggedTensor weights. If None, return None.

Returns:

the weights tensor.

Return type:

Optional[torch.Tensor]

class torchrec.sparse.jagged_tensor.KeyedJaggedTensor(*args, **kwargs)

Represents an (optionally weighted) keyed jagged tensor.

A KeyedJaggedTensor is a tensor with a jagged dimension, which is a dimension whose slices may be of different lengths. It is keyed on the first dimension and jagged on the last dimension.

Implementation is torch.jit.script-able.

Parameters:
  • keys (List[str]) – keys to the jagged Tensor.

  • values (torch.Tensor) – values tensor in dense representation.

  • weights (Optional[torch.Tensor]) – weights for the values, if any; a tensor with the same shape as values.

  • lengths (Optional[torch.Tensor]) – jagged slices, represented as lengths.

  • offsets (Optional[torch.Tensor]) – jagged slices, represented as cumulative offsets.

  • stride (Optional[int]) – number of examples per batch.

  • stride_per_key_per_rank (Optional[List[List[int]]]) – batch size (number of examples) per key per rank, with the outer list representing the keys and the inner list representing the values. Each value in the inner list represents the number of examples in the batch from the rank of its index in a distributed context.

  • length_per_key (Optional[List[int]]) – total length (number of values) for each key.

  • offset_per_key (Optional[List[int]]) – start offset for each key and final offset.

  • index_per_key (Optional[Dict[str, int]]) – index for each key.

  • jt_dict (Optional[Dict[str, JaggedTensor]]) – dictionary of keys to JaggedTensors. Allow ability to make to_dict() lazy/cacheable.

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to expand deduplicated embedding output for variable stride per key.

Example:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

# dim_0: keyed dimension (i.e. `Feature0`, `Feature1`)
# dim_1: optional second dimension (i.e. batch size)
# dim_2: the jagged dimension, which has slice lengths between 0 and 3 in the above example

# We represent this data with following inputs:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7]  # V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7]  # W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3]  # representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8]  # offsets from 0 for each jagged slice
keys: List[str] = ["Feature0", "Feature1"]  # correspond to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}  # index for each key
offset_per_key: List[int] = [0, 3, 8]  # start offset for each key and final offset

static concat(kjt_list: List[KeyedJaggedTensor]) KeyedJaggedTensor

Concatenates a list of KeyedJaggedTensors into a single KeyedJaggedTensor.

Parameters:

kjt_list (List[KeyedJaggedTensor]) – list of KeyedJaggedTensors to be concatenated.

Returns:

concatenated KeyedJaggedTensor.

Return type:

KeyedJaggedTensor
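
An illustrative sketch; it assumes the inputs share the same stride and that concatenation merges their keys:

kjt_a = KeyedJaggedTensor.from_lengths_sync(
    keys=["Feature0"],
    values=torch.tensor([0, 1, 2]),
    lengths=torch.tensor([2, 1]),
)
kjt_b = KeyedJaggedTensor.from_lengths_sync(
    keys=["Feature1"],
    values=torch.tensor([3, 4]),
    lengths=torch.tensor([1, 1]),
)

kjt_ab = KeyedJaggedTensor.concat([kjt_a, kjt_b])
# expected: kjt_ab.keys() == ["Feature0", "Feature1"]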

device() device

Returns the device of the KeyedJaggedTensor.

Returns:

device of the KeyedJaggedTensor.

Return type:

torch.device

static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) KeyedJaggedTensor

Constructs an empty KeyedJaggedTensor.

Parameters:
  • is_weighted (bool) – whether the KeyedJaggedTensor is weighted or not.

  • device (Optional[torch.device]) – device on which the KeyedJaggedTensor will be placed.

  • values_dtype (Optional[torch.dtype]) – dtype of the values tensor.

  • weights_dtype (Optional[torch.dtype]) – dtype of the weights tensor.

  • lengths_dtype (torch.dtype) – dtype of the lengths tensor.

Returns:

empty KeyedJaggedTensor.

Return type:

KeyedJaggedTensor

static empty_like(kjt: KeyedJaggedTensor) KeyedJaggedTensor

Constructs an empty KeyedJaggedTensor with the same device and dtypes as the input KeyedJaggedTensor.

Parameters:

kjt (KeyedJaggedTensor) – input KeyedJaggedTensor.

Returns:

empty KeyedJaggedTensor.

Return type:

KeyedJaggedTensor

static from_jt_dict(jt_dict: Dict[str, JaggedTensor]) KeyedJaggedTensor

Constructs a KeyedJaggedTensor from a dictionary of JaggedTensors. Automatically calls kjt.sync() on the newly created KJT.

Note

This function will ONLY work if the JaggedTensors all have the same “implicit” batch_size dimension.

Basically, a JaggedTensor can be visualized as a 2-D tensor of the format [batch_size x variable_feature_dim]. If some example in the batch has no value for a feature, the input JaggedTensor could simply omit any values for it.

But a KeyedJaggedTensor (by default) pads with “None” so that all of the JaggedTensors it stores share the same batch_size dimension. So if a JaggedTensor input does not already account for the empty examples, this function will error / not work.

Consider the visualization of the following KeyedJaggedTensor:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

Now if the input is

jt_dict = {
    # "Feature0"   [V0,V1]        [V2]
    # "Feature1"   [V3]    [V4]   [V5,V6,V7]
}

with the “None” entries left out of each JaggedTensor, this function would fail because it cannot tell where within each JaggedTensor the missing batch entries should be padded.

Essentially, the lengths tensor inferred by this function would be [2, 1, 1, 1, 3], indicating a variable batch size along dim_1, which violates the precondition that a KeyedJaggedTensor has a fixed batch_size dimension.

Parameters:

jt_dict (Dict[str, JaggedTensor]) – dictionary of JaggedTensors.

Returns:

constructed KeyedJaggedTensor.

Return type:

KeyedJaggedTensor
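
An illustrative sketch in which every JaggedTensor shares the same implicit batch size (here 3), as required by the note above; the values are placeholders:

jt_dict = {
    "Feature0": JaggedTensor(
        values=torch.Tensor([1.0, 2.0, 3.0]),
        lengths=torch.tensor([2, 0, 1]),
    ),
    "Feature1": JaggedTensor(
        values=torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0]),
        lengths=torch.tensor([1, 1, 3]),
    ),
}

kjt = KeyedJaggedTensor.from_jt_dict(jt_dict)
# expected: kjt.keys() == ["Feature0", "Feature1"] and kjt.stride() == 3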

static from_lengths_sync(keys: List[str], values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor

Constructs a KeyedJaggedTensor from a list of keys, a values tensor, and a lengths tensor. Same as from_offsets_sync except lengths are used instead of offsets.

Parameters:
  • keys (List[str]) – list of keys.

  • values (torch.Tensor) – values tensor in dense representation.

  • lengths (torch.Tensor) – jagged slices, represented as lengths.

  • weights (Optional[torch.Tensor]) – if the values have weights. Tensor with the same shape as values.

  • stride (Optional[int]) – number of examples per batch.

  • stride_per_key_per_rank (Optional[List[List[int]]]) – batch size (number of examples) per key per rank, with the outer list representing the keys and the inner list representing the values.

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to expand deduplicated embedding output for variable stride per key.

Returns:

constructed KeyedJaggedTensor.

Return type:

KeyedJaggedTensor
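
A minimal sketch mirroring the class-level example above:

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["Feature0", "Feature1"],
    values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
    lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
)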

static from_offsets_sync(keys: List[str], values: Tensor, offsets: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor

Constructs a KeyedJaggedTensor from a list of keys, values, and offsets.

Parameters:
  • keys (List[str]) – list of keys.

  • values (torch.Tensor) – values tensor in dense representation.

  • offsets (torch.Tensor) – jagged slices, represented as cumulative offsets.

  • weights (Optional[torch.Tensor]) – if the values have weights. Tensor with the same shape as values.

  • stride (Optional[int]) – number of examples per batch.

  • stride_per_key_per_rank (Optional[List[List[int]]]) – batch size (number of examples) per key per rank, with the outer list representing the keys and the inner list representing the values.

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to expand deduplicated embedding output for variable stride per key.

Returns:

constructed KeyedJaggedTensor.

Return type:

KeyedJaggedTensor
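
The same structure expressed with offsets instead of lengths (a minimal sketch):

kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["Feature0", "Feature1"],
    values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)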

index_per_key() Dict[str, int]

Returns the index per key of the KeyedJaggedTensor.

Returns:

index per key of the KeyedJaggedTensor.

Return type:

Dict[str, int]

inverse_indices() Tuple[List[str], Tensor]

Returns the inverse indices of the KeyedJaggedTensor. If inverse indices are None, this will throw an error.

Returns:

inverse indices of the KeyedJaggedTensor.

Return type:

Tuple[List[str], torch.Tensor]

inverse_indices_or_none() Optional[Tuple[List[str], Tensor]]

Returns the inverse indices of the KeyedJaggedTensor or None if they don’t exist.

Returns:

inverse indices of the KeyedJaggedTensor.

Return type:

Optional[Tuple[List[str], torch.Tensor]]

keys() List[str]

Returns the keys of the KeyedJaggedTensor.

Returns:

keys of the KeyedJaggedTensor.

Return type:

List[str]

length_per_key() List[int]

Returns the length per key of the KeyedJaggedTensor. If length per key is None, this will compute it.

Returns:

length per key of the KeyedJaggedTensor.

Return type:

List[int]

length_per_key_or_none() Optional[List[int]]

Returns the length per key of the KeyedJaggedTensor or None if it hasn’t been computed.

Returns:

length per key of the KeyedJaggedTensor.

Return type:

Optional[List[int]]

lengths() Tensor

Returns the lengths of the KeyedJaggedTensor. If the lengths are not computed yet, it will compute them.

Returns:

lengths of the KeyedJaggedTensor.

Return type:

torch.Tensor

lengths_offset_per_key() List[int]

Returns the lengths offset per key of the KeyedJaggedTensor. If lengths offset per key is None, this will compute it.

Returns:

lengths offset per key of the KeyedJaggedTensor.

Return type:

List[int]

lengths_or_none() Optional[Tensor]

Returns the lengths of the KeyedJaggedTensor or None if they are not computed yet.

Returns:

lengths of the KeyedJaggedTensor.

Return type:

Optional[torch.Tensor]

offset_per_key() List[int]

Returns the offset per key of the KeyedJaggedTensor. If offset per key is None, this will compute it.

Returns:

offset per key of the KeyedJaggedTensor.

Return type:

List[int]

offset_per_key_or_none() Optional[List[int]]

Returns the offset per key of the KeyedJaggedTensor or None if it hasn’t been computed.

Returns:

offset per key of the KeyedJaggedTensor.

Return type:

Optional[List[int]]

offsets() Tensor

Returns the offsets of the KeyedJaggedTensor. If the offsets are not computed yet, it will compute them.

Returns:

offsets of the KeyedJaggedTensor.

Return type:

torch.Tensor

offsets_or_none() Optional[Tensor]

Returns the offsets of the KeyedJaggedTensor or None if they are not computed yet.

Returns:

offsets of the KeyedJaggedTensor.

Return type:

Optional[torch.Tensor]

permute(indices: List[int], indices_tensor: Optional[Tensor] = None) KeyedJaggedTensor

Permutes the KeyedJaggedTensor.

Parameters:
  • indices (List[int]) – list of indices.

  • indices_tensor (Optional[torch.Tensor]) – tensor of indices.

Returns:

permuted KeyedJaggedTensor.

Return type:

KeyedJaggedTensor
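
An illustrative sketch; it assumes indices gives the new ordering of the keys:

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["Feature0", "Feature1"],
    values=torch.tensor([0, 1, 2, 3]),
    lengths=torch.tensor([2, 0, 1, 1]),
)

permuted = kjt.permute([1, 0])
# expected: permuted.keys() == ["Feature1", "Feature0"]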

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

split(segments: List[int]) List[KeyedJaggedTensor]

Splits the KeyedJaggedTensor into a list of KeyedJaggedTensor.

Parameters:

segments (List[int]) – list of segments.

Returns:

list of KeyedJaggedTensor.

Return type:

List[KeyedJaggedTensor]
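
An illustrative sketch; it assumes segments gives the number of keys in each output KeyedJaggedTensor:

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["Feature0", "Feature1", "Feature2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
)

first, rest = kjt.split([1, 2])
# expected: first.keys() == ["Feature0"], rest.keys() == ["Feature1", "Feature2"]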

stride() int

Returns the stride of the KeyedJaggedTensor. If stride is None, this will compute it.

Returns:

stride of the KeyedJaggedTensor.

Return type:

int

stride_per_key() List[int]

Returns the stride per key of the KeyedJaggedTensor. If stride per key is None, this will compute it.

Returns:

stride per key of the KeyedJaggedTensor.

Return type:

List[int]

stride_per_key_per_rank() List[List[int]]

Returns the stride per key per rank of the KeyedJaggedTensor.

Returns:

stride per key per rank of the KeyedJaggedTensor.

Return type:

List[List[int]]

sync() KeyedJaggedTensor

Synchronizes the KeyedJaggedTensor by computing the offset_per_key and length_per_key.

Returns:

synced KeyedJaggedTensor.

Return type:

KeyedJaggedTensor

to(device: device, non_blocking: bool = False, dtype: Optional[dtype] = None) KeyedJaggedTensor

Returns a copy of KeyedJaggedTensor in the specified device and dtype.

Parameters:
  • device (torch.device) – the desired device of the copy.

  • non_blocking (bool) – whether to copy the tensors in a non-blocking fashion.

  • dtype (Optional[torch.dtype]) – the desired data type of the copy.

Returns:

the copied KeyedJaggedTensor.

Return type:

KeyedJaggedTensor

to_dict() Dict[str, JaggedTensor]

Returns a dictionary of JaggedTensor for each key. Will cache result in self._jt_dict.

Returns:

dictionary of JaggedTensor for each key.

Return type:

Dict[str, JaggedTensor]
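
A small usage sketch:

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["Feature0", "Feature1"],
    values=torch.Tensor([1.0, 2.0, 3.0, 4.0]),
    lengths=torch.tensor([2, 0, 1, 1]),
)

jt_dict = kjt.to_dict()
jt_dict["Feature0"].values()   # expected: tensor([1., 2.])
jt_dict["Feature1"].lengths()  # expected: tensor([1, 1])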

unsync() KeyedJaggedTensor

Unsyncs the KeyedJaggedTensor by clearing the offset_per_key and length_per_key.

Returns:

unsynced KeyedJaggedTensor.

Return type:

KeyedJaggedTensor

values() Tensor

Returns the values of the KeyedJaggedTensor.

Returns:

values of the KeyedJaggedTensor.

Return type:

torch.Tensor

variable_stride_per_key() bool

Returns whether the KeyedJaggedTensor has variable stride per key.

Returns:

whether the KeyedJaggedTensor has variable stride per key.

Return type:

bool

weights() Tensor

Returns the weights of the KeyedJaggedTensor. If weights is None, this will throw an error.

Returns:

weights of the KeyedJaggedTensor.

Return type:

torch.Tensor

weights_or_none() Optional[Tensor]

Returns the weights of the KeyedJaggedTensor or None if they don’t exist.

Returns:

weights of the KeyedJaggedTensor.

Return type:

Optional[torch.Tensor]

class torchrec.sparse.jagged_tensor.KeyedTensor(*args, **kwargs)

KeyedTensor holds a concatenated list of dense tensors, each of which can be accessed by a key.

The keyed dimension can be of variable length (length_per_key). Common use cases include storage of pooled embeddings of different dimensions.

Implementation is torch.jit.script-able.

Parameters:
  • keys (List[str]) – list of keys.

  • length_per_key (List[int]) – length of each key along key dimension.

  • values (torch.Tensor) – dense tensor, concatenated typically along key dimension.

  • key_dim (int) – key dimension, zero indexed - defaults to 1 (typically B is 0-dimension).

Example:

# kt is KeyedTensor holding

#                         0           1           2
#     "Embedding A"    [1,1]       [1,1]        [1,1]
#     "Embedding B"    [2,1,2]     [2,1,2]      [2,1,2]
#     "Embedding C"    [3,1,2,3]   [3,1,2,3]    [3,1,2,3]

tensor_list = [
    torch.tensor([[1,1]] * 3),
    torch.tensor([[2,1,2]] * 3),
    torch.tensor([[3,1,2,3]] * 3),
]

keys = ["Embedding A", "Embedding B", "Embedding C"]

kt = KeyedTensor.from_tensor_list(keys, tensor_list)

kt.values()
# torch.Tensor(
#     [
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#     ]
# )

kt["Embedding B"]
# torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])

device() device
Returns:

device of the values tensor.

Return type:

torch.device

static from_tensor_list(keys: List[str], tensors: List[Tensor], key_dim: int = 1, cat_dim: int = 1) KeyedTensor

Create a KeyedTensor from a list of tensors. The tensors are concatenated along the cat_dim. The keys are used to index the tensors.

Parameters:
  • keys (List[str]) – list of keys.

  • tensors (List[torch.Tensor]) – list of tensors.

  • key_dim (int) – key dimension, zero indexed - defaults to 1 (typically B is 0-dimension).

  • cat_dim (int) – dimension along which to concatenate the tensors - defaults to 1.

Returns:

keyed tensor.

Return type:

KeyedTensor

key_dim() int
Returns:

key dimension, zero indexed - typically B is 0-dimension.

Return type:

int

keys() List[str]
Returns:

list of keys.

Return type:

List[str]

length_per_key() List[int]
Returns:

length of each key along key dimension.

Return type:

List[int]

offset_per_key() List[int]

Get the offset of each key along key dimension. Compute and cache if not already computed.

Returns:

offset of each key along key dimension.

Return type:

List[int]

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

static regroup(keyed_tensors: List[KeyedTensor], groups: List[List[str]]) List[Tensor]

Regroup a list of KeyedTensors into a list of tensors.

Parameters:
  • keyed_tensors (List[KeyedTensor]) – list of KeyedTensors.

  • groups (List[List[str]]) – list of groups of keys.

Returns:

list of tensors.

Return type:

List[torch.Tensor]
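
An illustrative sketch reusing kt from the class-level example above; it assumes each group is concatenated into one tensor along the key dimension:

grouped = KeyedTensor.regroup(
    keyed_tensors=[kt],
    groups=[["Embedding A", "Embedding C"], ["Embedding B"]],
)

# expected: grouped[0] has shape (3, 6) – "Embedding A" and "Embedding C" columns
# expected: grouped[1] has shape (3, 3) – "Embedding B" columns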

static regroup_as_dict(keyed_tensors: List[KeyedTensor], groups: List[List[str]], keys: List[str]) Dict[str, Tensor]

Regroup a list of KeyedTensors into a dictionary of tensors.

Parameters:
  • keyed_tensors (List[KeyedTensor]) – list of KeyedTensors.

  • groups (List[List[str]]) – list of groups of keys.

  • keys (List[str]) – list of keys.

Returns:

dictionary of tensors.

Return type:

Dict[str, torch.Tensor]

to(device: device, non_blocking: bool = False) KeyedTensor

Moves the values tensor to the specified device.

Parameters:
  • device (torch.device) – device to move the values tensor to.

  • non_blocking (bool) – whether to perform the operation asynchronously (default: False).

Returns:

keyed tensor with values tensor moved to the specified device.

Return type:

KeyedTensor

to_dict() Dict[str, Tensor]
Returns:

dictionary of tensors keyed by the keys.

Return type:

Dict[str, torch.Tensor]

values() Tensor

Get the values tensor.

Returns:

dense tensor, concatenated typically along key dimension.

Return type:

torch.Tensor
