
isin

tensordict.utils.isin(input: TensorDictBase, reference: TensorDictBase, key: NestedKey, dim: int = 0)

Tests whether each element of the entry key in input, taken along dimension dim, is also present in reference.

This function returns a boolean tensor of length input.batch_size[dim] whose i-th value is True if the i-th slice of the entry key along dim also appears in the same entry of reference. Both input and reference must contain the specified entry and have the same number of batch dimensions; their sizes along dim may differ, as in the example below. Otherwise, an error is raised.
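The slice-wise comparison can be illustrated in plain PyTorch by giving every distinct slice along dim an integer id with torch.unique(..., dim=dim, return_inverse=True) and then checking membership with torch.isin. The following is a minimal sketch, assuming exact (element-wise equal) matching of whole slices. rowwise_isin is a hypothetical helper, not part of tensordict, and it operates on raw tensors rather than TensorDicts.

```python
import torch


def rowwise_isin(values: torch.Tensor, reference: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper (not a tensordict API): for each slice of `values`
    along `dim`, return True if an identical slice exists in `reference`."""
    n_ref = reference.shape[dim]
    # Put the reference slices first, then the slices to test, along `dim`.
    # All other dimensions must match for torch.cat to succeed.
    combined = torch.cat([reference, values], dim=dim)
    # Assign an integer id to every distinct slice; identical slices share an id.
    _, ids = torch.unique(combined, dim=dim, return_inverse=True)
    # A slice of `values` is present if its id occurs among the reference ids.
    return torch.isin(ids[n_ref:], ids[:n_ref])


# Same data as the "tensor1" entries in the example below.
values = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]])
reference = torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]])

mask = rowwise_isin(values, reference)
print(mask.tolist())            # [True, True, True, False]
print(values[~mask].tolist())   # [[7, 8, 9]]

# With dim=1, whole columns are compared instead of rows.
print(rowwise_isin(values.T, reference.T, dim=1).tolist())  # [True, True, True, False]
```

Because the result is an ordinary boolean tensor over the batch dimension, negating it and indexing with it keeps only the slices that are absent from the reference.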

Parameters:
  • input (TensorDictBase) – Input TensorDict.

  • reference (TensorDictBase) – Target TensorDict against which to test.

  • key (NestedKey) – The (possibly nested) key of the entry to compare.

  • dim (int, optional) – The dimension along which to test. Defaults to 0.

Returns:

A boolean tensor of length input.batch_size[dim] that is True for elements in the input key tensor that are also present in the reference.

Return type:

out (Tensor)

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.utils import isin
>>> td = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
...     },
...     batch_size=[4],
... )
>>> td_ref = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
...     },
...     batch_size=[3],
... )
>>> in_reference = isin(td, td_ref, key="tensor1")
>>> expected_in_reference = torch.tensor([True, True, True, False])
>>> torch.testing.assert_close(in_reference, expected_in_reference)
