from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size

from ._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass

D = TypeVar("D", bound="TVTensor")


class TVTensor(torch.Tensor):
    """Base class for all TVTensors.

    You probably don't want to use this class unless you're defining your own custom TVTensors. See
    :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for details.
    """

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    @classmethod
    def _wrap_output(
        cls,
        output: torch.Tensor,
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        # Same as torch._tensor._convert
        if isinstance(output, torch.Tensor) and not isinstance(output, cls):
            output = output.as_subclass(cls)

        if isinstance(output, (tuple, list)):
            # Also handles things like namedtuples
            output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
        return output

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works, see
        https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        Why do we override this? Because the base implementation in torch.Tensor would preserve the TVTensor type
        of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
        "TVTensors FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).

        Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
        """
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Like in the base Tensor.__torch_function__ implementation, it's easier to always use
        # DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

        must_return_subclass = _must_return_subclass()
        if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
            # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
            # in test_to_tv_tensor_reference().
            # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
            # the computation by walking the MRO upwards. For example,
            # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
            # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
            # be wrapped into an `Image`.
            return cls._wrap_output(output, args, kwargs)

        if not must_return_subclass and isinstance(output, cls):
            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
            # so for those, the output is still a TVTensor. Thus, we need to manually unwrap.
            return output.as_subclass(torch.Tensor)

        return output

    def _make_repr(self, **kwargs: Any) -> str:
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"

    # Add properties for common attributes like shape, dtype, device, ndim, etc.
    # This way we return the result without going through __torch_function__.
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype

    def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
        # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
        # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
        # attribute is cleared, so we need to refill it before we return.
        # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
        # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied
        # by `BoundingBoxes.clone()`.
        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
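A short usage sketch of the unwrapping behaviour implemented by ``__torch_function__`` above. This assumes a standard torchvision install; ``tv_tensors.Image`` and ``tv_tensors.set_return_type`` are public torchvision APIs, and the asserted types follow from the branches in the method, but the snippet itself is illustrative rather than taken from the docs.

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 32, 32))
assert isinstance(img, tv_tensors.Image)

# Regular operators go through __torch_function__; since _must_return_subclass()
# is False by default, the result is a pure tensor, not an Image.
out = img + 1
assert type(out) is torch.Tensor

# In-place ops ignore DisableTorchFunctionSubclass, so their raw output is still
# an Image; the final branch above unwraps it via as_subclass(torch.Tensor).
out = img.add_(1)
assert type(out) is torch.Tensor

# set_return_type flips _must_return_subclass(), so operators re-wrap their
# output into the input's TVTensor subclass while the context is active.
with tv_tensors.set_return_type("TVTensor"):
    assert type(img + 1) is tv_tensors.Image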
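And a minimal sketch of defining a custom TVTensor, in the spirit of the ``sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py`` example referenced in the class docstring. ``MyTVTensor`` is a hypothetical name introduced for illustration; the ``_to_tensor`` + ``as_subclass`` construction pattern mirrors the one used by the built-in subclasses.

from typing import Any, Optional, Union

import torch
from torchvision import tv_tensors


class MyTVTensor(tv_tensors.TVTensor):
    """Hypothetical TVTensor subclass carrying no extra metadata."""

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> "MyTVTensor":
        # _to_tensor handles dtype/device conversion and the requires_grad
        # default, then as_subclass tags the tensor with our subclass type.
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return tensor.as_subclass(cls)


my_t = MyTVTensor([1.0, 2.0])
assert isinstance(my_t, MyTVTensor)
assert type(my_t + 1) is torch.Tensor  # plain ops still unwrap, as shown above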