is_tensor_collection
- tensordict.is_tensor_collection(datatype: Union[type, Any]) → bool
Checks whether a data object or a type is a tensor container from the tensordict library.
- Returns:
True if the input is a TensorDictBase subclass, a tensorclass, or an instance of either; False otherwise.
Examples
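>>> from tensordict import TensorDict, TensorDictBase, is_tensor_collection, tensorclass
>>> is_tensor_collection(TensorDictBase)  # True
>>> is_tensor_collection(TensorDict())  # True
>>> @tensorclass
... class MyClass:
...     pass
...
>>> is_tensor_collection(MyClass)  # True
>>> is_tensor_collection(MyClass(batch_size=[]))  # True
>>> is_tensor_collection({"a": 1})  # False: a plain dict is not a tensor collection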
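The contract above (accept either a type or an instance, then test whether it belongs to one of two families) can be illustrated with a minimal, dependency-free sketch. This is not tensordict's implementation: FakeTensorDictBase, FakeTensorDict, fake_tensorclass, and is_tensor_collection_sketch are hypothetical stand-ins, and marking decorated classes with a `_is_fake_tensorclass` attribute is an assumption made only for illustration.

```python
from typing import Any


# Hypothetical stand-ins: these are NOT tensordict classes. They only mimic the
# two kinds of "tensor collection" described in the docstring.
class FakeTensorDictBase:
    """Stand-in for tensordict.TensorDictBase."""


class FakeTensorDict(FakeTensorDictBase):
    """Stand-in for a concrete subclass such as tensordict.TensorDict."""


def fake_tensorclass(cls: type) -> type:
    """Stand-in for @tensorclass: tags the decorated class with a marker attribute (assumption)."""
    cls._is_fake_tensorclass = True
    return cls


def is_tensor_collection_sketch(datatype: Any) -> bool:
    # Accept either a type or an instance: normalize instances to their type.
    cls = datatype if isinstance(datatype, type) else type(datatype)
    # A collection is either a subclass of the base class or a "decorated" class.
    return issubclass(cls, FakeTensorDictBase) or bool(
        getattr(cls, "_is_fake_tensorclass", False)
    )


@fake_tensorclass
class MyClass:
    pass


print(is_tensor_collection_sketch(FakeTensorDict))    # True
print(is_tensor_collection_sketch(FakeTensorDict()))  # True
print(is_tensor_collection_sketch(MyClass))           # True
print(is_tensor_collection_sketch(MyClass()))         # True
print(is_tensor_collection_sketch(dict))              # False
print(is_tensor_collection_sketch(42))                # False
```

Normalizing to a type first is what lets a single check serve both `MyClass` and `MyClass(...)`, mirroring the paired examples in the docstring.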