Source code for torchvision.transforms.v2.functional._utils
import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch
from torchvision import tv_tensors

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


def is_pure_tensor(inpt: Any) -> bool:
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, tv_tensors.TVTensor)


# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


def _kernel_tv_tensor_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # If you're wondering whether we could / should get rid of this wrapper,
        # the answer is no: we want to pass pure Tensors to avoid the overhead
        # of the __torch_function__ machinery. Note that this is always valid,
        # regardless of whether we override __torch_function__ in our base class
        # or not.
        # Also, even if we didn't call `as_subclass` here, we would still need
        # this wrapper to call wrap(), because the TVTensor type would be
        # lost after the first operation due to our own __torch_function__
        # logic.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return tv_tensors.wrap(output, like=inpt)

    return wrapper


def _register_kernel_internal(functional, input_type, *, tv_tensor_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_tv_tensor_wrapper(kernel)
            if issubclass(input_type, tv_tensors.TVTensor) and tv_tensor_wrapper
            else kernel
        )
        return kernel

    return decorator


def _name_to_functional(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
        ) from None


_BUILTIN_DATAPOINT_TYPES = {
    obj
    for obj in tv_tensors.__dict__.values()
    if isinstance(obj, type) and issubclass(obj, tv_tensors.TVTensor)
}
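For illustration, a minimal sketch of how `_register_kernel_internal` and `_kernel_tv_tensor_wrapper` interact. The functional `adjust_nothing` and its kernel are hypothetical, and the sketch assumes the private helpers from this module (including `_get_kernel`, defined further down) are in scope:

# Example only, not part of the module. `adjust_nothing` is hypothetical.
import torch
from torchvision import tv_tensors

def adjust_nothing(inpt):
    # A real v2 functional dispatches via _get_kernel (defined below).
    return _get_kernel(adjust_nothing, type(inpt))(inpt)

@_register_kernel_internal(adjust_nothing, tv_tensors.Image)
def adjust_nothing_image(image):
    # _kernel_tv_tensor_wrapper has already unwrapped the Image to a pure
    # Tensor at this point, avoiding __torch_function__ overhead; the output
    # is re-wrapped into an Image by tv_tensors.wrap on the way out.
    return image

out = adjust_nothing(tv_tensors.Image(torch.rand(3, 4, 4)))
assert isinstance(out, tv_tensors.Image)  # type preserved by the wrapper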
def register_kernel(functional, tv_tensor_cls):
    """Decorate a kernel to register it for a functional and a (custom) tv_tensor type.

    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for usage
    details.
    """
    if isinstance(functional, str):
        functional = _name_to_functional(name=functional)
    elif not (
        callable(functional)
        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
            f"but got {functional}."
        )

    if not (isinstance(tv_tensor_cls, type) and issubclass(tv_tensor_cls, tv_tensors.TVTensor)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.tv_tensors.TVTensor, "
            f"but got {tv_tensor_cls}."
        )

    if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES:
        raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}")

    return _register_kernel_internal(functional, tv_tensor_cls, tv_tensor_wrapper=False)
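A hedged usage sketch of the public `register_kernel`, along the lines of the gallery example referenced in the docstring. `MyTVTensor` and its kernel are hypothetical:

# Example only, not part of the module. MyTVTensor is hypothetical.
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

class MyTVTensor(tv_tensors.TVTensor):
    pass

@F.register_kernel(functional="horizontal_flip", tv_tensor_cls=MyTVTensor)
def horizontal_flip_my_tv_tensor(my_tv_tensor, *args, **kwargs):
    # Custom kernels are registered with tv_tensor_wrapper=False, so they
    # receive the TVTensor itself and must wrap their own output.
    out = my_tv_tensor.flip(-1)
    return tv_tensors.wrap(out, like=my_tv_tensor)

t = MyTVTensor(torch.rand(3, 4, 4))
assert isinstance(F.horizontal_flip(t), MyTVTensor)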
def _get_kernel(functional, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
        elif cls is tv_tensors.TVTensor:
            # We don't want user-defined tv_tensors to dispatch to the pure Tensor kernels, so we explicitly stop
            # the MRO traversal before hitting torch.Tensor. We can even stop at tv_tensors.TVTensor, since we
            # don't allow kernels to be registered for tv_tensors.TVTensor anyway.
            break

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )


# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop.
# We could get rid of this by letting _register_kernel_internal take an arbitrary wrapper instead of the
# `tv_tensor_wrapper: bool` flag.
def _register_five_ten_crop_kernel_internal(functional, input_type):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            # five_crop / ten_crop return a container of crops rather than a
            # single output, so each element is wrapped back individually.
            output = kernel(inpt, *args, **kwargs)
            container_type = type(output)
            return container_type(tv_tensors.wrap(o, like=inpt) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, tv_tensors.TVTensor) else kernel
        return kernel

    return decorator
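To make the MRO traversal in `_get_kernel` concrete, a small hypothetical sketch (both class names are made up; assumes the private `_get_kernel` is in scope):

# Example only, not part of the module. Both classes are hypothetical.
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

class FancyImage(tv_tensors.Image):
    pass

class UnknownTVTensor(tv_tensors.TVTensor):
    pass

# FancyImage has no kernel of its own, but Image appears in its MRO, so the
# built-in Image kernel is returned:
kernel = _get_kernel(F.horizontal_flip, FancyImage)

# The traversal stops at TVTensor, so an unregistered custom class raises
# TypeError instead of silently falling through to the pure-Tensor kernel:
try:
    _get_kernel(F.horizontal_flip, UnknownTVTensor)
except TypeError:
    pass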