Shortcuts

Source code for torchvision.tv_tensors._torch_function_helpers

import torch

# Module-wide switch for the return type of torch operations on TVTensors.
# ``set_return_type`` writes it (False for "Tensor", True for "TVTensor"),
# ``_ReturnTypeCM.__exit__`` restores it, and ``_must_return_subclass`` reads it.
# The default, False, corresponds to returning pure tensors.
_TORCHFUNCTION_SUBCLASS = False


class _ReturnTypeCM:
    """Context manager handed back by ``set_return_type``.

    Entering does nothing besides yielding the manager itself; leaving writes
    the saved value back into the module-level ``_TORCHFUNCTION_SUBCLASS``
    flag.
    """

    def __init__(self, to_restore):
        # Value that will be written back to the module-level flag on exit.
        self.to_restore = to_restore

    def __enter__(self):
        """Return the manager itself; the flag is not modified on entry."""
        return self

    def __exit__(self, *exc_info):
        """Write the saved value back to the module-level flag.

        The exception details are ignored and ``None`` is returned, so an
        exception raised inside the ``with`` body keeps propagating.
        """
        global _TORCHFUNCTION_SUBCLASS
        _TORCHFUNCTION_SUBCLASS = self.to_restore


def set_return_type(return_type: str):
    """[BETA] Set the return type of torch operations on :class:`~torchvision.tv_tensors.TVTensor`.

    This only affects the behaviour of torch operations. It has no effect on
    ``torchvision`` transforms or functionals, which will always return as
    output the same type that was passed as input.

    .. warning::

        We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
        the end of your transform pipelines if you use
        ``set_return_type("TVTensor")``. This will avoid the
        ``__torch_function__`` overhead in the model's ``forward()``.

    Can be used as a global flag for the entire program:

    .. code:: python

        img = tv_tensors.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor (default behaviour)

        set_return_type("TVTensor")
        img + 2  # This is an Image

    or as a context manager to restrict the scope:

    .. code:: python

        img = tv_tensors.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor
        with set_return_type("TVTensor"):
            img + 2  # This is an Image
        img + 2  # This is a pure Tensor

    Args:
        return_type (str): Can be "TVTensor" or "Tensor" (case-insensitive).
            Default is "Tensor" (i.e. pure :class:`torch.Tensor`).

    Returns:
        A context manager that, on exit, restores the setting that was active
        before this call. The new setting takes effect immediately, whether or
        not the returned object is used in a ``with`` statement.

    Raises:
        ValueError: If ``return_type`` is not "TVTensor" or "Tensor"
            (case-insensitive), including when it is not a string.
    """
    global _TORCHFUNCTION_SUBCLASS

    # Validate first so that an invalid argument never touches the module flag.
    # AttributeError covers non-string inputs (no ``.lower()``), which would
    # otherwise surface as a confusing error unrelated to the accepted values.
    try:
        must_return_subclass = {"tensor": False, "tvtensor": True}[return_type.lower()]
    except (AttributeError, KeyError):
        raise ValueError(f"return_type must be 'TVTensor' or 'Tensor', got {return_type}") from None

    # Remember the previous setting so the context manager can restore it.
    to_restore = _TORCHFUNCTION_SUBCLASS
    _TORCHFUNCTION_SUBCLASS = must_return_subclass
    return _ReturnTypeCM(to_restore)
def _must_return_subclass():
    """Return the current value of the module-level ``_TORCHFUNCTION_SUBCLASS`` flag.

    ``set_return_type`` sets that flag to ``True`` for "TVTensor" and to
    ``False`` for "Tensor".
    """
    return _TORCHFUNCTION_SUBCLASS


# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources