isin
- tensordict.utils.isin(input: TensorDictBase, reference: TensorDictBase, key: NestedKey, dim: int = 0)
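Tests whether each element of the entry key in input, taken along dimension dim, is also present in reference.

This function returns a boolean tensor of length input.batch_size[dim] that is True for elements of the entry key that are also present in reference. An element is the whole slice of the entry at one index along dim: in the example below, the row [1, 2, 3] counts as present only because reference contains that exact row. This function assumes that input and reference have matching batch sizes except possibly along dim (the example below uses batch sizes [4] and [3]) and that both contain the specified entry; otherwise an error is raised.
- Parameters: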
  - input (TensorDictBase) – Input TensorDict.
  - reference (TensorDictBase) – Reference TensorDict against which input is tested.
  - key (NestedKey) – The key to test.
  - dim (int, optional) – The dimension along which to test. Defaults to 0.
- Returns:
  - A boolean tensor of length input.batch_size[dim] that is True for elements of the input key tensor that are also present in reference.
- Return type:
out (Tensor)
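The Returns entry describes a per-index membership test. As a minimal sketch of those semantics in plain Python, assuming dim=0 and entries given as nested lists, each input row can be checked against a set of reference rows. The helper rows_in_reference is hypothetical and not part of tensordict; it uses a hash set, which is not necessarily how isin is implemented.

```python
from typing import List, Sequence


def rows_in_reference(
    input_rows: Sequence[Sequence[int]],
    reference_rows: Sequence[Sequence[int]],
) -> List[bool]:
    """Hypothetical helper: True where a whole input row also appears in reference."""
    # Rows are compared as whole tuples, so [1, 2, 3] only matches [1, 2, 3].
    reference_set = {tuple(row) for row in reference_rows}
    # One boolean per input row, i.e. one per index along dim=0.
    return [tuple(row) in reference_set for row in input_rows]


mask = rows_in_reference(
    [[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]],
    [[1, 2, 3], [4, 5, 6], [10, 11, 12]],
)
print(mask)  # [True, True, True, False]
```

Duplicated input rows are each marked True (index 2 repeats index 0), and the result has one entry per input row, which corresponds to a length of input.batch_size[dim].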
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.utils import isin
>>> td = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
...     },
...     batch_size=[4],
... )
>>> td_ref = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
...     },
...     batch_size=[3],
... )
>>> in_reference = isin(td, td_ref, key="tensor1")
>>> expected_in_reference = torch.tensor([True, True, True, False])
>>> torch.testing.assert_close(in_reference, expected_in_reference)
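For plain tensors, the same slice-membership test can be vectorized without Python loops. The sketch below concatenates the reference and target along dim, labels each distinct slice with torch.unique(..., dim=dim, return_inverse=True), and then uses torch.isin to check which target labels also occur among the reference labels. It is an illustrative sketch, not presented as tensordict's own implementation; the function name isin_slices is hypothetical, and it assumes both tensors match in every dimension except dim.

```python
import torch


def isin_slices(target: torch.Tensor, reference: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: True for each slice of target along dim found in reference."""
    n_ref = reference.shape[dim]
    # Place reference slices first, then target slices, along dim.
    combined = torch.cat([reference, target], dim=dim)
    # inverse[i] is the id of the distinct slice at position i along dim,
    # so equal slices share the same id.
    _, inverse = torch.unique(combined, dim=dim, return_inverse=True)
    # Target ids follow the first n_ref positions; keep those seen in the reference.
    return torch.isin(inverse[n_ref:], inverse[:n_ref])


target = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]])
reference = torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]])
print(isin_slices(target, reference).tolist())  # [True, True, True, False]
```

Passing dim=1 compares columns instead of rows, so isin_slices(target.T, reference.T, dim=1) produces the same mask. Labeling slices with torch.unique keeps the comparison on the tensor backend instead of looping over rows in Python.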