"""The torch package contains data structures for multi-dimensionaltensors and defines mathematical operations over these tensors.Additionally, it provides many utilities for efficient serialization ofTensors and arbitrary types, and other useful utilities.It has a CUDA counterpart, that enables you to run your tensor computationson an NVIDIA GPU with compute capability >= 3.0."""# mypy: allow-untyped-defsimportbuiltinsimportctypesimportglobimportimportlibimportinspectimportmathimportosimportplatformimportsysimporttextwrapimportthreadingfromtypingimport(Anyas_Any,Callableas_Callable,get_originas_get_origin,Optionalas_Optional,overloadas_overload,TYPE_CHECKING,TypeVaras_TypeVar,Unionas_Union,)fromtyping_extensionsimportParamSpecas_ParamSpecifTYPE_CHECKING:from.typesimportIntLikeType# multipy/deploy is setting this import before importing torch, this is the most# reliable way we have to detect if we're running within deploy.# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137def_running_with_deploy()->builtins.bool:returnsys.modules.get("torch._meta_registrations",None)isobjectfromtorch._utilsimport(_functionalize_syncas_sync,_import_dotted_name,classproperty,)fromtorch._utils_internalimport(get_file_path,prepare_multiprocessing_environment,USE_GLOBAL_DEPS,USE_RTLD_GLOBAL_WITH_LIBTORCH,)# TODO(torch_deploy) figure out how to freeze version.py in fbcode buildif_running_with_deploy():__version__="torch-deploy-1.8"# TODO: Remove this ugly hack when deploy typing extensions are updated to 4.10+ifnotTYPE_CHECKING:importtyping_extensions_TypeIs=typing_extensions.TypeGuardtyping_extensions.TypeIs=_TypeIselse:fromtyping_extensionsimportTypeIsas_TypeIsfromtorch.torch_versionimport__version__as__version____all__=["BoolStorage","BoolTensor","ByteStorage","ByteTensor","CharStorage","CharTensor","DoubleStorage","DoubleTensor","FloatStorage","FloatTensor","GradScaler","IntStorage","IntTensor","LongStorage","LongTensor","ShortStorage","ShortTensor","SymBool","SymFloat","SymInt","Tensor","TypedStorage","UntypedStorage","are_deterministic_algorithms_enabled","autocast","chunk","compile","cond","enable_grad","export","get_default_device","get_deterministic_debug_mode","get_device_module","get_float32_matmul_precision","get_rng_state","inference_mode","initial_seed","is_deterministic_algorithms_warn_only_enabled","is_storage","is_tensor","is_warn_always_enabled","load","lobpcg","manual_seed","matmul","no_grad","rand","randn","save","seed","set_default_device","set_default_tensor_type","set_deterministic_debug_mode","set_float32_matmul_precision","set_printoptions","set_rng_state","set_warn_always","split","stack","sym_float","sym_fresh_size","sym_int","sym_ite","sym_max","sym_min","sym_not","sym_sum","typename","unravel_index","use_deterministic_algorithms","vmap",]# Please keep this list sortedassert__all__==sorted(__all__)################################################################################# Load the extension module################################################################################ifsys.platform=="win32":def_load_dll_libraries()->None:importsysconfigfromtorch.versionimportcudaascuda_versionpfiles_path=os.getenv("ProgramFiles",r"C:\Program Files")py_dll_path=os.path.join(sys.exec_prefix,"Library","bin")th_dll_path=os.path.join(os.path.dirname(__file__),"lib")usebase_path=os.path.join(sysconfig.get_config_var("userbase"),"Library","bin")# When users create a virtualenv that inherits the base 
environment,# we will need to add the corresponding library directory into# DLL search directories. Otherwise, it will rely on `PATH` which# is dependent on user settings.ifsys.exec_prefix!=sys.base_exec_prefix:base_py_dll_path=os.path.join(sys.base_exec_prefix,"Library","bin")else:base_py_dll_path=""dll_paths=[pforpin(th_dll_path,py_dll_path,base_py_dll_path,usebase_path)ifos.path.exists(p)]ifnotbuiltins.any(os.path.exists(os.path.join(p,"nvToolsExt64_1.dll"))forpindll_paths):nvtoolsext_dll_path=os.path.join(os.getenv("NVTOOLSEXT_PATH",os.path.join(pfiles_path,"NVIDIA Corporation","NvToolsExt"),),"bin","x64",)else:nvtoolsext_dll_path=""ifcuda_versionandbuiltins.all(notglob.glob(os.path.join(p,"cudart64*.dll"))forpindll_paths):cuda_version_1=cuda_version.replace(".","_")cuda_path_var="CUDA_PATH_V"+cuda_version_1default_path=os.path.join(pfiles_path,"NVIDIA GPU Computing Toolkit","CUDA",f"v{cuda_version}")cuda_path=os.path.join(os.getenv(cuda_path_var,default_path),"bin")else:cuda_path=""dll_paths.extend(pforpin(nvtoolsext_dll_path,cuda_path)ifos.path.exists(p))kernel32=ctypes.WinDLL("kernel32.dll",use_last_error=True)with_load_library_flags=hasattr(kernel32,"AddDllDirectory")prev_error_mode=kernel32.SetErrorMode(0x0001)kernel32.LoadLibraryW.restype=ctypes.c_void_pifwith_load_library_flags:kernel32.LoadLibraryExW.restype=ctypes.c_void_pfordll_pathindll_paths:os.add_dll_directory(dll_path)try:ctypes.CDLL("vcruntime140.dll")ctypes.CDLL("msvcp140.dll")ifplatform.machine()!="ARM64":ctypes.CDLL("vcruntime140_1.dll")exceptOSError:print(textwrap.dedent(""" Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe """).strip())dlls=glob.glob(os.path.join(th_dll_path,"*.dll"))path_patched=Falsefordllindlls:is_loaded=Falseifwith_load_library_flags:res=kernel32.LoadLibraryExW(dll,None,0x00001100)last_error=ctypes.get_last_error()ifresisNoneandlast_error!=126:err=ctypes.WinError(last_error)err.strerror+=(f' Error loading "{dll}" or one of its dependencies.')raiseerrelifresisnotNone:is_loaded=Trueifnotis_loaded:ifnotpath_patched:os.environ["PATH"]=";".join(dll_paths+[os.environ["PATH"]])path_patched=Trueres=kernel32.LoadLibraryW(dll)ifresisNone:err=ctypes.WinError(ctypes.get_last_error())err.strerror+=(f' Error loading "{dll}" or one of its dependencies.')raiseerrkernel32.SetErrorMode(prev_error_mode)_load_dll_libraries()del_load_dll_librariesdef_preload_cuda_deps(lib_folder:str,lib_name:str)->None:"""Preloads cuda deps if they could not be found otherwise."""# Should only be called on Linux if default path resolution have failedassertplatform.system()=="Linux","Should only be called on Linux"lib_path=Noneforpathinsys.path:nvidia_path=os.path.join(path,"nvidia")ifnotos.path.exists(nvidia_path):continuecandidate_lib_paths=glob.glob(os.path.join(nvidia_path,lib_folder,"lib",lib_name))# if path/nvidia/lib_folder/ is not found look in path/lib_folder/ifnotcandidate_lib_paths:candidate_lib_paths=glob.glob(os.path.join(path,lib_folder,"lib",lib_name))ifcandidate_lib_pathsandnotlib_path:lib_path=candidate_lib_paths[0]iflib_path:breakifnotlib_path:raiseValueError(f"{lib_name} not found in the system path {sys.path}")ctypes.CDLL(lib_path)# See Note [Global dependencies]def_load_global_deps()->None:if_running_with_deploy()orplatform.system()=="Windows":return# Determine the file extension based on the 
platformlib_ext=".dylib"ifplatform.system()=="Darwin"else".so"lib_name=f"libtorch_global_deps{lib_ext}"here=os.path.abspath(__file__)global_deps_lib_path=os.path.join(os.path.dirname(here),"lib",lib_name)try:ctypes.CDLL(global_deps_lib_path,mode=ctypes.RTLD_GLOBAL)# Workaround slim-wheel CUDA dependency bugs in cusparse and cudnn by preloading nvjitlink# and nvrtc. In CUDA-12.4+ cusparse depends on nvjitlink, but does not have rpath when# shipped as wheel, which results in OS picking wrong/older version of nvjitlink library# if `LD_LIBRARY_PATH` is defined, see https://github.com/pytorch/pytorch/issues/138460# Similar issue exist in cudnn that dynamically loads nvrtc, unaware of its relative path.# See https://github.com/pytorch/pytorch/issues/145580try:withopen("/proc/self/maps")asf:_maps=f.read()# libtorch_global_deps.so always depends in cudart, check if its installed via wheelif"nvidia/cuda_runtime/lib/libcudart.so"notin_maps:return# If all above-mentioned conditions are met, preload nvrtc and nvjitlink# Please note that order are important for CUDA-11.8 , as nvjitlink does not exist there_preload_cuda_deps("cuda_nvrtc","libnvrtc.so.*[0-9]")_preload_cuda_deps("nvjitlink","libnvJitLink.so.*[0-9]")exceptException:passexceptOSErroraserr:# Can only happen for wheel with cuda libs as PYPI deps# As PyTorch is not purelib, but nvidia-*-cu12 iscuda_libs:dict[str,str]={"cublas":"libcublas.so.*[0-9]","cudnn":"libcudnn.so.*[0-9]","cuda_nvrtc":"libnvrtc.so.*[0-9]","cuda_runtime":"libcudart.so.*[0-9]","cuda_cupti":"libcupti.so.*[0-9]","cufft":"libcufft.so.*[0-9]","curand":"libcurand.so.*[0-9]","nvjitlink":"libnvJitLink.so.*[0-9]","cusparse":"libcusparse.so.*[0-9]","cusparselt":"libcusparseLt.so.*[0-9]","cusolver":"libcusolver.so.*[0-9]","nccl":"libnccl.so.*[0-9]","nvtx":"libnvToolsExt.so.*[0-9]",}is_cuda_lib_err=[libforlibincuda_libs.values()iflib.split(".")[0]inerr.args[0]]ifnotis_cuda_lib_err:raiseerrforlib_folder,lib_nameincuda_libs.items():_preload_cuda_deps(lib_folder,lib_name)ctypes.CDLL(global_deps_lib_path,mode=ctypes.RTLD_GLOBAL)if(USE_RTLD_GLOBAL_WITH_LIBTORCHoros.getenv("TORCH_USE_RTLD_GLOBAL"))and(_running_with_deploy()orplatform.system()!="Windows"):# Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a# few circumstances:## 1. You're in a build environment (e.g., fbcode) where# libtorch_global_deps is not available, but you still need# to get mkl to link in with RTLD_GLOBAL or it will just# not work.## 2. You're trying to run PyTorch under UBSAN and you need# to ensure that only one copy of libtorch is loaded, so# vptr checks work properly## If you're using this setting, you must verify that all the libraries# you load consistently use the same libstdc++, or you may have# mysterious segfaults.#old_flags=sys.getdlopenflags()sys.setdlopenflags(os.RTLD_GLOBAL|os.RTLD_LAZY)fromtorch._Cimport*# noqa: F403sys.setdlopenflags(old_flags)delold_flagselse:# Easy way. You want this most of the time, because it will prevent# C++ symbols from libtorch clobbering C++ symbols from other# libraries, leading to mysterious segfaults.## If building in an environment where libtorch_global_deps isn't available# like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will# want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False## See Note [Global dependencies]ifUSE_GLOBAL_DEPS:_load_global_deps()fromtorch._Cimport*# noqa: F403
class SymInt:
    """
    Like an int (including magic methods), but redirects all operations on the
    wrapped node. This is used in particular to symbolically record operations
    in the symbolic shape workflow.
    """

    def __init__(self, node):
        # This field MUST be named node; C++ binding code assumes that this
        # class has a field named node that stores SymNode
        self.node = node

    def __bool__(self):
        return builtins.bool(self != 0)

    def __int__(self):
        return self.node.int_()

    def __index__(self):
        return self.node.int_()

    # Magic methods installed by torch.fx.experimental.sym_node
    def __round__(self, ndigits=None):
        return self

    def __truediv__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__float_truediv__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__int_truediv__(other)

    def __rtruediv__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__rfloat_truediv__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__rint_truediv__(other)

    def __floordiv__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(math.floor(sym_float(self) / other))
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__int_floordiv__(other)

    def __rfloordiv__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(math.floor(other / sym_float(self)))
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__rint_floordiv__(other)

    # nb: complex is impossible to handle correctly lol, with
    # negative base and integral float need to diverge semantics and
    # just always return complex. Neener neener pretend this problem
    # doesn't exist
    def __pow__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__pow__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        # Guards! This guard is necessary because we need to know it to
        # determine the output type of this operation
        if other >= 0:
            return self.__pow_by_natural__(other)
        else:
            # Mercifully, when the exponent is negative, Python just promotes
            # to doubles and does a float pow:
            #
            #   if (Py_SIZE(b) < 0 && c == NULL) {
            #       /* if exponent is negative and there's no modulus:
            #              return a float.  This works because we know
            #              that this calls float_pow() which converts its
            #              arguments to double. */
            #       Py_DECREF(a);
            #       Py_DECREF(b);
            #       return PyFloat_Type.tp_as_number->nb_power(v, w, x);
            #   }
            return sym_float(self).__pow__(sym_float(other))

    def __rpow__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__rpow__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        if self >= 0:  # self is exponent
            return self.__rpow_by_natural__(other)
        else:
            return sym_float(self).__rpow__(sym_float(other))

    def __eq__(self, other: object) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __lt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __gt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __le__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __ge__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __add__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __radd__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __rmul__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __mod__(self, other: "IntLikeType") -> "SymInt":
        raise TypeError("type stub not overridden")

    def __mul__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __pow_by_natural__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __rpow_by_natural__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __int_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rint_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __int_floordiv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rint_floordiv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __sym_max__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_min__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_float__(self):
        raise TypeError("type stub not overridden")

    def __neg__(self):
        raise TypeError("type stub not overridden")

    def __sub__(self, other: "IntLikeType") -> "SymInt":
        raise TypeError("type stub not overridden")

    def __rsub__(self, other: "IntLikeType") -> "SymInt":
        raise TypeError("type stub not overridden")

    def __and__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __or__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __repr__(self):
        return self.node._graph_repr()

    def _sympy_(self):
        return self.node.expr

    def __hash__(self) -> builtins.int:
        if self.node.is_nested_int():
            return hash(self.node.nested_int())
        else:
            # We could support constant SymInts as well, but not doing it for now
            raise TypeError("unhashable type: non-nested SymInt")
            # TODO: Force specialization
            # This can't be done because the TypeError here is load bearing
            # for einops
            # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
            # return hash(builtins.int(self))
    def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
        """Represent this int as an exact integer ratio"""
        return self, 1
    def bit_length(self) -> builtins.int:
        # TODO: A more relaxed guard is possible here, where you guard to
        # allow all integer quantities which would result in the same bit
        # length. We can also just make a dedicated Sympy function for
        # computing this quantity and represent it symbolically.
        return builtins.int(self).bit_length()

    def conjugate(self) -> "SymInt":
        return self
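
# Illustrative sketch (not part of the module): SymInts typically show up as
# the sizes of tensors traced with dynamic shapes. Inside a compiled region,
# `x.shape[0]` is a SymInt and arithmetic on it is recorded symbolically
# rather than evaluated eagerly. Doctest-style, assuming a recent build with
# torch.compile available:
#
#     >>> @torch.compile(dynamic=True)
#     ... def f(x):
#     ...     n = x.shape[0]            # a SymInt under the hood
#     ...     return x.new_zeros(n * 2 + 1)
#     >>> f(torch.randn(3)).shape
#     torch.Size([7])
#
# Note that true division of SymInts promotes to SymFloat, mirroring Python's
# int semantics (3 / 2 == 1.5), while // stays integral.
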
class SymFloat:
    """
    Like a float (including magic methods), but redirects all operations on the
    wrapped node. This is used in particular to symbolically record operations
    in the symbolic shape workflow.
    """

    def __init__(self, node):
        # This field MUST be named node; C++ binding code assumes that this
        # class has a field named node that stores SymNode
        self.node = node

    def __truediv__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return self.__float_truediv__(sym_float(other))

    def __rtruediv__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return self.__rfloat_truediv__(sym_float(other))

    def __floordiv__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return sym_float(math.floor(self / sym_float(other)))

    def __rfloordiv__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return sym_float(math.floor(sym_float(other) / self))

    def __bool__(self):
        return self.node.bool_()

    def __float__(self):
        return self.node.guard_float("", 0)

    # Symbolic power does NOT work with negative base, this is to avoid
    # potential complex outputs
    def __pow__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        torch._check(self >= 0)
        return self.__float_pow__(other)

    def __rpow__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        torch._check(other >= 0)
        return self.__rfloat_pow__(other)

    # Magic methods installed by torch.fx.experimental.sym_node
    def __eq__(self, other: object) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __lt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __gt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __le__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __ge__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __float_pow__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rfloat_pow__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __float_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rfloat_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __trunc__(self):
        raise TypeError("type stub not overridden")

    def __sym_max__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_min__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_int__(self):
        raise TypeError("type stub not overridden")
    def is_integer(self):
        """Return True if the float is an integer."""
        raise TypeError("type stub not overridden")
    def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
        """Represent this float as an exact integer ratio"""
        return builtins.float(self).as_integer_ratio()
    def conjugate(self) -> "SymFloat":
        """Returns the complex conjugate of the float."""
        return self
    def hex(self) -> str:
        """Returns the hexadecimal representation of the float."""
        return self.node.guard_float("", 0).hex()
class SymBool:
    """
    Like a bool (including magic methods), but redirects all operations on the
    wrapped node. This is used in particular to symbolically record operations
    in the symbolic shape workflow.

    Unlike regular bools, regular boolean operators will force extra guards
    instead of symbolically evaluating. Use the bitwise operators instead
    to handle this.
    """

    def __init__(self, node):
        # This field MUST be named node; C++ binding code assumes that this
        # class has a field named node that stores SymNode
        self.node = node

    def __bool__(self):
        return self.node.bool_()

    def __int__(self):
        return builtins.int(self.node.bool_())

    # Magic methods installed by torch.fx.experimental.sym_node
    def __and__(self, other) -> "SymBool":
        raise TypeError("type stub not overridden")

    def __or__(self, other) -> "SymBool":
        raise TypeError("type stub not overridden")

    # We very carefully define __sym_not__, and not a number of other
    # plausible alternatives:
    #
    #   - We do not override __not__ because this is not a real magic
    #     method; you cannot override the meaning of the not builtin in
    #     Python. We use the name 'sym_not' to clarify that in user code you
    #     cannot use the builtin not or operator.not_ or operator.__not__ and
    #     hit this magic method; you must use our custom sym_not operator.
    #
    #   - We do not override the __invert__ method because SymBool is
    #     meant to be usable in situations where bool is expected. However,
    #     bitwise negation ~a does the wrong thing with booleans (because
    #     bool is a subclass of int, so ~1 = -2 which is not falseish.)
    #     This would be a giant footgun, so we get around it by defining
    #     our own operator. Note that bitwise and/or do the right thing,
    #     so we reuse the conventional operators there for readability.
    #
    def __sym_not__(self) -> "SymBool":
        raise TypeError("type stub not overridden")

    def __sym_ite__(self, then_val, else_val):
        raise TypeError("type stub not overridden")

    def __eq__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __repr__(self):
        return self.node._graph_repr()

    def _sympy_(self):
        return self.node.expr

    def __hash__(self):
        if self.node.is_constant():
            return hash(self.node.bool_())
        else:
            # Force specialization
            return hash(builtins.bool(self))
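
# Illustrative sketch (not part of the module): per the class docstring above,
# prefer the bitwise/custom operators on SymBool so the condition stays
# symbolic. Using `and`, `or`, or `not` calls __bool__, which inserts a guard
# and specializes the trace:
#
#     >>> # inside a dynamically-shaped region, with a, b SymBools:
#     >>> c = a & b                    # stays symbolic
#     >>> d = torch.sym_not(a)         # symbolic negation; `not a` would guard
#     >>> e = torch.sym_ite(a, 1, 0)   # symbolic if-then-else
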
def sym_not(a):
    r"""SymInt-aware utility for logical negation.

    Args:
        a (SymBool or bool): Object to negate
    """
    import sympy

    if overrides.has_torch_function_unary(a):
        return overrides.handle_torch_function(sym_not, (a,), a)
    if hasattr(a, "__sym_not__"):
        return a.__sym_not__()
    if isinstance(a, sympy.Basic):
        return ~a  # type: ignore[operator]
    return not a
def sym_float(a):
    r"""SymInt-aware utility for float casting.

    Args:
        a (SymInt, SymFloat, or object): Object to cast
    """
    if overrides.has_torch_function_unary(a):
        return overrides.handle_torch_function(sym_float, (a,), a)
    if isinstance(a, SymFloat):
        return a
    elif hasattr(a, "__sym_float__"):
        return a.__sym_float__()
    return builtins.float(a)  # type: ignore[operator]
def sym_int(a):
    r"""SymInt-aware utility for int casting.

    Args:
        a (SymInt, SymFloat, or object): Object to cast
    """
    if overrides.has_torch_function_unary(a):
        return overrides.handle_torch_function(sym_int, (a,), a)
    if isinstance(a, SymInt):
        return a
    elif isinstance(a, SymFloat):
        return math.trunc(a)
    return builtins.int(a)  # type: ignore[operator]
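
# Illustrative note (not part of the module): like int(), sym_int truncates
# toward zero rather than flooring, which matters for negative values:
#
#     >>> torch.sym_int(2.7)
#     2
#     >>> torch.sym_int(-2.7)   # trunc, not floor (floor would give -3)
#     -2
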
def sym_max(a, b):
    """
    SymInt-aware utility for max which avoids branching on a < b.
    Unlike builtins.max(), this only works for int/float, and it always
    promotes to float if any argument is float (unlike builtins.max, which
    will faithfully preserve the type of the input argument).
    """
    if overrides.has_torch_function((a, b)):
        return overrides.handle_torch_function(sym_max, (a, b), a, b)
    if isinstance(a, (SymInt, SymFloat)):
        return a.__sym_max__(b)
    elif isinstance(b, (SymInt, SymFloat)):
        # Due to promotion semantics, this operator is commutative:
        # max(1, 1.0) === max(1.0, 1) === 1.0
        return b.__sym_max__(a)
    # TODO: Probably can make bool work too, just lazy
    all_types, float_types = __all_and_float_types()

    assert isinstance(a, all_types), type(a)
    assert isinstance(b, all_types), type(b)
    if isinstance(a, float_types) or isinstance(b, float_types):
        return builtins.float(builtins.max(a, b))  # type: ignore[call-overload]
    else:
        return builtins.max(a, b)  # type: ignore[call-overload]
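
# Illustrative note (not part of the module): with plain numbers, sym_max
# behaves like max() except for its promotion rule:
#
#     >>> torch.sym_max(3, 5)
#     5
#     >>> torch.sym_max(3, 5.0)   # any float argument promotes the result
#     5.0
#     >>> builtins.max(5, 3.0)    # builtins.max preserves the winner's type
#     5
#
# The symbolic payoff is that max over SymInts records a single expression
# instead of branching on `a < b`, so no guard is introduced.
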
def sym_min(a, b):
    """SymInt-aware utility for min()."""
    if overrides.has_torch_function((a, b)):
        return overrides.handle_torch_function(sym_min, (a, b), a, b)
    if isinstance(a, (SymInt, SymFloat)):
        return a.__sym_min__(b)
    elif isinstance(b, (SymInt, SymFloat)):
        return b.__sym_min__(a)
    all_types, float_types = __all_and_float_types()

    assert isinstance(a, all_types), type(a)
    assert isinstance(b, all_types), type(b)
    if isinstance(a, float_types) or isinstance(b, float_types):
        return builtins.float(builtins.min(a, b))  # type: ignore[call-overload]
    else:
        return builtins.min(a, b)  # type: ignore[call-overload]
def sym_sum(args):
    """
    N-ary add which is faster to compute for long lists than iterated binary
    addition. Only does something special for integers.
    """
    if overrides.has_torch_function(args):
        return overrides.handle_torch_function(sym_sum, args, args)

    found = None
    for a in args:
        if not isinstance(a, (SymInt, builtins.int)):
            return builtins.sum(args)
        if isinstance(a, SymInt):
            found = a.node
    if found is None:
        return builtins.sum(args)

    from torch.fx.experimental.sym_node import to_node, wrap_node

    return wrap_node(found.sym_sum(tuple(to_node(found, a) for a in args)))
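
# Illustrative note (not part of the module): on plain ints, sym_sum is just
# sum(); the fast path only kicks in when at least one SymInt is present, in
# which case a single n-ary sum node replaces a chain of binary __add__ nodes:
#
#     >>> torch.sym_sum([1, 2, 3])
#     6
#     >>> # with s0, s1 SymInts: sym_sum([s0, s1, 1]) records one sum node
#     >>> # instead of (s0 + s1) + 1
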
# Drop in replacement for math.sqrt, math.sin, math.cos etc
def _get_sym_math_fn(name):
    def fn(a):
        if overrides.has_torch_function_unary(a):
            return overrides.handle_torch_function(fn, (a,), a)
        if isinstance(a, SymInt):
            a = torch.sym_float(a)
        if hasattr(a, f"__sym_{name}__"):
            return getattr(a, f"__sym_{name}__")()
        return getattr(math, name)(a)

    return fn


__fn, __name, __sym_name = None, "", ""
for __name in (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
    "log2",
):
    __sym_name = f"_sym_{__name}"
    __fn = _get_sym_math_fn(__name)
    __fn.__qualname__ = __fn.__name__ = __sym_name
    globals()[__sym_name] = __fn

del __fn, __name, __sym_name, _get_sym_math_fn

# Adding temporary shortcut
sym_sqrt = globals()["_sym_sqrt"]
__all__.append("sym_sqrt")
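
# Illustrative note (not part of the module): each generated wrapper defers to
# math.* on plain numbers and to the __sym_*__ hook on symbolic values, so
# torch.sym_sqrt works on both:
#
#     >>> torch.sym_sqrt(9)
#     3.0
#     >>> # with a SymInt s0: torch.sym_sqrt(s0) first casts via sym_float,
#     >>> # then records a symbolic sqrt instead of guarding on the value.
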
# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
except ImportError:
    import torch._C as _C_for_compiled_check

    # The __file__ check only works for Python 3.7 and above.
    if _C_for_compiled_check.__file__ is None:
        raise ImportError(
            textwrap.dedent(
                """
                Failed to load PyTorch C extensions:
                    It appears that PyTorch has loaded the `torch/_C` folder
                    of the PyTorch repository rather than the C extensions which
                    are expected in the `torch._C` namespace. This can occur when
                    using the `install` workflow. e.g.
                        $ python setup.py install && python -c "import torch"

                    This error can generally be solved using the `develop` workflow
                        $ python setup.py develop && python -c "import torch"  # This should succeed
                    or by running Python from a different directory.
                """
            ).strip()
        ) from None
    raise  # If __file__ is not None the cause is unknown, so just re-raise.

# The torch._C submodule is already loaded via `from torch._C import *` above
# Make an explicit reference to the _C submodule to appease linters
from torch import _C as _C

__name, __obj = "", None
for __name in dir(_C):
    if __name[0] != "_" and not __name.endswith("Base"):
        __all__.append(__name)
        __obj = getattr(_C, __name)
        if callable(__obj) or inspect.isclass(__obj):
            if __obj.__module__ != __name__:  # "torch"
                # TODO: fix their module from C++ side
                if __name not in {
                    "DisableTorchFunctionSubclass",
                    "DisableTorchFunction",
                    "Generator",
                }:
                    __obj.__module__ = __name__  # "torch"
    elif __name == "TensorBase":
        # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
        delattr(sys.modules[__name__], __name)

del __name, __obj

if not TYPE_CHECKING:
    # issue 38137 and python issue 43367. Submodules of a C extension are
    # non-standard, and attributes of those submodules cannot be pickled since
    # pickle expects to be able to import them as "from _C.sub import attr",
    # which fails with "_C is not a package".

    def _import_extension_to_sys_modules(module, memo=None):
        if memo is None:
            memo = set()
        if module in memo:
            return
        memo.add(module)
        module_name = module.__name__
        for name in dir(module):
            member = getattr(module, name)
            member_name = getattr(member, "__name__", "")
            if inspect.ismodule(member) and member_name.startswith(module_name):
                sys.modules.setdefault(member_name, member)
                # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
                _import_extension_to_sys_modules(member, memo)

    _import_extension_to_sys_modules(_C)
    del _import_extension_to_sys_modules


################################################################################
# Define basic utilities
################################################################################


def typename(obj: _Any, /) -> str:
    """
    String representation of the type of an object.

    This function returns a fully qualified string representation of an object's type.

    Args:
        obj (object): The object whose type to represent

    Returns:
        str: the type of the object ``obj``

    Example:
        >>> x = torch.tensor([1, 2, 3])
        >>> torch.typename(x)
        'torch.LongTensor'
        >>> torch.typename(torch.nn.Parameter)
        'torch.nn.parameter.Parameter'
    """
    if isinstance(obj, torch.Tensor):
        return obj.type()

    module = getattr(obj, "__module__", "") or ""
    qualname = ""

    if hasattr(obj, "__qualname__"):
        qualname = obj.__qualname__
    elif hasattr(obj, "__name__"):
        qualname = obj.__name__
    else:
        module = obj.__class__.__module__ or ""
        qualname = obj.__class__.__qualname__

    if module in {"", "builtins"}:
        return qualname
    return f"{module}.{qualname}"
def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
    r"""Returns True if `obj` is a PyTorch tensor.

    Note that this function is simply doing ``isinstance(obj, Tensor)``.
    Using that ``isinstance`` check is better for typechecking with mypy,
    and more explicit - so it's recommended to use that instead of
    ``is_tensor``.

    Args:
        obj (object): Object to test
    Example::

        >>> x = torch.tensor([1, 2, 3])
        >>> torch.is_tensor(x)
        True

    """
    return isinstance(obj, torch.Tensor)
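
# Illustrative note (not part of the module): the docstring's recommendation
# in practice -- a plain isinstance check narrows types for checkers just as
# well and is more explicit:
#
#     >>> x = torch.tensor([1, 2, 3])
#     >>> isinstance(x, torch.Tensor)   # preferred over torch.is_tensor(x)
#     True
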
def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]:
    r"""Returns True if `obj` is a PyTorch storage object.

    Args:
        obj (Object): Object to test
    """
    return type(obj) in _storage_classes
_GLOBAL_DEVICE_CONTEXT = threading.local()
def get_default_device() -> "torch.device":
    r"""Gets the default ``torch.device`` that new ``torch.Tensor`` s are allocated on."""
    global _GLOBAL_DEVICE_CONTEXT

    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
        device = _GLOBAL_DEVICE_CONTEXT.device_context.device
        if device.index is not None:
            return device
        else:
            # TODO: Call like get_device_index() method corresponding to
            # each device type
            return torch.tensor([]).device
    else:
        return torch.device("cpu")
def set_default_device(
    device: _Optional[_Union["torch.device", str, builtins.int]],
) -> None:
    """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
    does not affect factory function calls which are called with an explicit
    ``device`` argument. Factory calls will be performed as if they
    were passed ``device`` as an argument.

    To only temporarily change the default device instead of setting it
    globally, use ``with torch.device(device):`` instead.

    The default device is initially ``cpu``. If you set the default tensor
    device to another device (e.g., ``cuda``) without a device index, tensors
    will be allocated on whatever the current device for the device type is,
    even after :func:`torch.cuda.set_device` is called.

    .. warning::

        This function imposes a slight performance cost on every Python
        call to the torch API (not just factory functions). If this
        is causing problems for you, please comment on
        https://github.com/pytorch/pytorch/issues/92701

    .. note::

        This doesn't affect functions that create tensors that share the same memory as the input, like:
        :func:`torch.from_numpy` and :func:`torch.frombuffer`

    Args:
        device (device or string): the device to set as default

    Example::

        >>> # xdoctest: +SKIP("requires cuda, changes global state")
        >>> torch.get_default_device()
        device(type='cpu')
        >>> torch.set_default_device('cuda')  # current device is 0
        >>> torch.get_default_device()
        device(type='cuda', index=0)
        >>> torch.set_default_device('cuda')
        >>> torch.cuda.set_device('cuda:1')  # current device is 1
        >>> torch.get_default_device()
        device(type='cuda', index=1)
        >>> torch.set_default_device('cuda:1')
        >>> torch.get_default_device()
        device(type='cuda', index=1)

    """
    global _GLOBAL_DEVICE_CONTEXT
    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
        device_context = _GLOBAL_DEVICE_CONTEXT.device_context
        if device_context is not None:
            device_context.__exit__(None, None, None)

    if device is None:
        device_context = None
    else:
        from torch.utils._device import DeviceContext

        device_context = DeviceContext(device)
        device_context.__enter__()
    _GLOBAL_DEVICE_CONTEXT.device_context = device_context
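
# Illustrative sketch (not part of the module): the temporary, scope-local
# alternative recommended in the docstring above:
#
#     >>> with torch.device("cpu"):
#     ...     t = torch.ones(2)   # factory call picks up the context device
#     >>> t.device
#     device(type='cpu')
#
# Prefer the context manager in library code; set_default_device mutates
# process-global state and adds overhead to every torch call.
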
def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
    r"""
    .. warning::

        This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and
        :func:`torch.set_default_device()` as alternatives.

    Sets the default ``torch.Tensor`` type to floating point tensor type
    ``t``. This type will also be used as default floating point type for
    type inference in :func:`torch.tensor`.

    The default floating point tensor type is initially ``torch.FloatTensor``.

    Args:
        t (type or string): the floating point tensor type or its name

    Example::

        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
        >>> torch.tensor([1.2, 3]).dtype  # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_tensor_type(torch.DoubleTensor)
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float64

    """
    if isinstance(t, str):
        t = _import_dotted_name(t)
    _C._set_default_tensor_type(t)
def set_default_dtype(d: "torch.dtype", /) -> None:
    r"""

    Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
    as inputs. Other dtypes will cause torch to raise an exception.

    When PyTorch is initialized its default floating point dtype is torch.float32,
    and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
    type inference. The default floating point dtype is used to:

    1. Implicitly determine the default complex dtype. When the default floating type is float16,
       the default complex dtype is complex32. For float32, the default complex dtype is complex64.
       For float64, it is complex128. For bfloat16, an exception will be raised because
       there is no corresponding complex type for bfloat16.
    2. Infer the dtype for tensors constructed using Python floats or complex Python
       numbers. See examples below.
    3. Determine the result of type promotion between bool and integer tensors and
       Python floats and complex Python numbers.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default.

    Example:
        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
        >>> # initial default for floating point is torch.float32
        >>> # Python floats are interpreted as float32
        >>> torch.tensor([1.2, 3]).dtype
        torch.float32
        >>> # initial default for complex is torch.complex64
        >>> # Complex Python numbers are interpreted as complex64
        >>> torch.tensor([1.2, 3j]).dtype
        torch.complex64

        >>> torch.set_default_dtype(torch.float64)
        >>> # Python floats are now interpreted as float64
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float64
        >>> # Complex Python numbers are now interpreted as complex128
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex128

        >>> torch.set_default_dtype(torch.float16)
        >>> # Python floats are now interpreted as float16
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float16
        >>> # Complex Python numbers are now interpreted as complex32
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex32

    """
    _C._set_default_dtype(d)
def use_deterministic_algorithms(
    mode: builtins.bool,
    *,
    warn_only: builtins.bool = False,
) -> None:
    r"""Sets whether PyTorch operations must use "deterministic"
    algorithms. That is, algorithms which, given the same input, and when
    run on the same software and hardware, always produce the same output.
    When enabled, operations will use deterministic algorithms when available,
    and if only nondeterministic algorithms are available they will throw a
    :class:`RuntimeError` when called.

    .. note:: This setting alone is not always enough to make an application
        reproducible. Refer to :ref:`reproducibility` for more information.

    .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
        interface for this feature.

    The following normally-nondeterministic operations will act
    deterministically when ``mode=True``:

        * :class:`torch.nn.Conv1d` when called on CUDA tensor
        * :class:`torch.nn.Conv2d` when called on CUDA tensor
        * :class:`torch.nn.Conv3d` when called on CUDA tensor
        * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
        * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
        * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
        * :func:`torch.bmm` when called on sparse-dense CUDA tensors
        * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
          and the index is a list of tensors
        * :func:`torch.Tensor.index_put` with ``accumulate=False``
        * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
          tensor
        * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
          tensor
        * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
        * :func:`torch.gather` when called on a CUDA tensor that requires grad
        * :func:`torch.index_add` when called on CUDA tensor
        * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
        * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
        * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
        * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
        * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'``
          and called on CUDA tensor

    The following normally-nondeterministic operations will throw a
    :class:`RuntimeError` when ``mode=True``:

        * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.MaxUnpool1d`
        * :class:`torch.nn.MaxUnpool2d`
        * :class:`torch.nn.MaxUnpool3d`
        * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor
          and one of the following modes is used:

          - ``linear``
          - ``bilinear``
          - ``bicubic``
          - ``trilinear``

        * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
        * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
        * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
          ``mode='max'``
        * :func:`torch.Tensor.put_` when ``accumulate=False``
        * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
        * :func:`torch.histc` when called on a CUDA tensor
        * :func:`torch.bincount` when called on a CUDA tensor and ``weights``
          tensor is given
        * :func:`torch.kthvalue` when called on a CUDA tensor
        * :func:`torch.median` with indices output when called on a CUDA tensor
        * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
        * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
        * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
        * :func:`torch.Tensor.resize_` when called with a quantized tensor

    In addition, several operations fill uninitialized memory when this setting
    is turned on and when
    :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
    See the documentation for that attribute for more information.

    A handful of CUDA operations are nondeterministic if the CUDA version is
    10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
    or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
    details: `<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_
    If one of these environment variable configurations is not set, a :class:`RuntimeError`
    will be raised from these operations when called with CUDA tensors:

        * :func:`torch.mm`
        * :func:`torch.mv`
        * :func:`torch.bmm`

    Note that deterministic operations tend to have worse performance than
    nondeterministic operations.

    .. note::

        This flag does not detect or prevent nondeterministic behavior caused
        by calling an inplace operation on a tensor with an internal memory
        overlap or by giving such a tensor as the :attr:`out` argument for an
        operation. In these cases, multiple writes of different data may target
        a single memory location, and the order of writes is not guaranteed.

    Args:
        mode (:class:`bool`): If True, makes potentially nondeterministic
            operations switch to a deterministic algorithm or throw a runtime
            error. If False, allows nondeterministic operations.

    Keyword args:
        warn_only (:class:`bool`, optional): If True, operations that do not
            have a deterministic implementation will throw a warning instead of
            an error. Default: ``False``

    Example::

        >>> # xdoctest: +SKIP
        >>> torch.use_deterministic_algorithms(True)

        # Forward mode nondeterministic error
        >>> torch.randn(10, device='cuda').kthvalue(1)
        ...
        RuntimeError: kthvalue CUDA does not have a deterministic implementation...

        # Backward mode nondeterministic error
        >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
        ...
        RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...
    """
    _C._set_deterministic_algorithms(mode, warn_only=warn_only)
def are_deterministic_algorithms_enabled() -> builtins.bool:
    r"""Returns True if the global deterministic flag is turned on. Refer to
    :func:`torch.use_deterministic_algorithms` documentation for more details.
    """
    return _C._get_deterministic_algorithms()
def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
    r"""Returns True if the global deterministic flag is set to warn only.
    Refer to :func:`torch.use_deterministic_algorithms` documentation for more
    details.
    """
    return _C._get_deterministic_algorithms_warn_only()
def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None:
    r"""Sets the debug mode for deterministic operations.

    .. note:: This is an alternative interface for
        :func:`torch.use_deterministic_algorithms`. Refer to that function's
        documentation for details about affected operations.

    Args:
        debug_mode(str or int): If "default" or 0, don't error or warn on
            nondeterministic operations. If "warn" or 1, warn on
            nondeterministic operations. If "error" or 2, error on
            nondeterministic operations.
    """
    # NOTE: builtins.int is used here because int in this scope resolves
    # to torch.int
    if not isinstance(debug_mode, (builtins.int, str)):
        raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}")

    if isinstance(debug_mode, str):
        if debug_mode == "default":
            debug_mode = 0
        elif debug_mode == "warn":
            debug_mode = 1
        elif debug_mode == "error":
            debug_mode = 2
        else:
            raise RuntimeError(
                "invalid value of debug_mode, expected one of `default`, "
                f"`warn`, `error`, but got {debug_mode}"
            )

    if debug_mode == 0:
        _C._set_deterministic_algorithms(False)
    elif debug_mode == 1:
        _C._set_deterministic_algorithms(True, warn_only=True)
    elif debug_mode == 2:
        _C._set_deterministic_algorithms(True)
    else:
        raise RuntimeError(
            "invalid value of debug_mode, expected 0, 1, or 2, "
            f"but got {debug_mode}"
        )
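
# Illustrative note (not part of the module): the two interfaces are
# interchangeable; the int/str debug modes map onto
# use_deterministic_algorithms as follows:
#
#     >>> torch.set_deterministic_debug_mode("warn")
#     >>> # equivalent to:
#     >>> torch.use_deterministic_algorithms(True, warn_only=True)
#     >>> torch.get_deterministic_debug_mode()
#     1
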
def get_deterministic_debug_mode() -> builtins.int:
    r"""Returns the current value of the debug mode for deterministic
    operations. Refer to :func:`torch.set_deterministic_debug_mode`
    documentation for more details.
    """
    if _C._get_deterministic_algorithms():
        if _C._get_deterministic_algorithms_warn_only():
            return 1
        else:
            return 2
    else:
        return 0
def get_float32_matmul_precision() -> str:
    r"""Returns the current value of float32 matrix multiplication precision.
    Refer to :func:`torch.set_float32_matmul_precision` documentation for more
    details.
    """
    return _C._get_float32_matmul_precision()
def set_float32_matmul_precision(precision: str) -> None:
    r"""Sets the internal precision of float32 matrix multiplications.

    Running float32 matrix multiplications in lower precision may significantly increase
    performance, and in some programs the loss of precision has a negligible impact.

    Supports three settings:

        * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
          bits with 23 bits explicitly stored) for internal computations.
        * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
          mantissa bits explicitly stored) or treat each float32 number as the sum of two
          bfloat16 numbers (approximately 16 mantissa bits with 14 bits explicitly stored),
          if the appropriate fast matrix multiplication algorithms are available. Otherwise
          float32 matrix multiplications are computed as if the precision is "highest".
          See below for more information on the bfloat16 approach.
        * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
          bits with 7 bits explicitly stored) for internal computations, if a fast matrix
          multiplication algorithm using that datatype internally is available. Otherwise
          float32 matrix multiplications are computed as if the precision is "high".

    When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
    that is more complicated than simply truncating to some smaller number of mantissa bits
    (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
    description of this algorithm. To briefly explain here, the first step is to realize
    that we can perfectly encode a single float32 number as the sum of three bfloat16
    numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
    same number of exponent bits). This means that the product of two float32 numbers can
    be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
    accuracy for speed by dropping some of these products. The "high" precision algorithm
    specifically keeps only the three most significant products, which conveniently excludes
    all of the products involving the last 8 mantissa bits of either input. This means that
    we can represent our inputs as the sum of two bfloat16 numbers rather than three.
    Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than
    float32 ones, it's faster to do three multiplications and two additions with bfloat16
    precision than it is to do a single multiplication with float32 precision.

    .. [Henry2019] http://arxiv.org/abs/1904.06376

    .. note::

        This does not change the output dtype of float32 matrix multiplications,
        it controls how the internal computation of the matrix multiplication is performed.

    .. note::

        This does not change the precision of convolution operations. Other flags,
        like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution
        operations.

    .. note::

        This flag currently only affects one native device type: CUDA.
        If "high" or "medium" are set then the TensorFloat32 datatype will be used
        when computing float32 matrix multiplications, equivalent to setting
        `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default)
        is set then the float32 datatype is used for internal computations, equivalent
        to setting `torch.backends.cuda.matmul.allow_tf32 = False`.

    Args:
        precision(str): can be set to "highest" (default), "high", or "medium" (see above).

    """
    _C._set_float32_matmul_precision(precision)
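
# Illustrative sketch (not part of the module) of the bfloat16 decomposition
# behind the "high" setting described above: a float32 x splits into bfloat16
# pieces whose sum approximates x, and a product of two such sums expands into
# bfloat16 partial products. A hedged sketch using torch itself:
#
#     >>> x = torch.tensor(0.1, dtype=torch.float32)
#     >>> hi = x.to(torch.bfloat16).to(torch.float32)          # top ~8 mantissa bits
#     >>> lo = (x - hi).to(torch.bfloat16).to(torch.float32)   # next ~8 bits
#     >>> # x ~= hi + lo; a product a*b then expands to
#     >>> # a_hi*b_hi + a_hi*b_lo + a_lo*b_hi (+ a_lo*b_lo, dropped by "high")
#     >>> bool(abs((hi + lo) - x) < abs(hi - x))
#     True
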
def set_warn_always(b: builtins.bool, /) -> None:
    r"""When this flag is False (default) then some PyTorch warnings may only
    appear once per process. This helps avoid excessive warning information.
    Setting it to True causes these warnings to always appear, which may be
    helpful when debugging.

    Args:
        b (:class:`bool`): If True, force warnings to always be emitted.
            If False, set to the default behaviour.
    """
    _C._set_warnAlways(b)
def is_warn_always_enabled() -> builtins.bool:
    r"""Returns True if the global warn_always flag is turned on. Refer to
    :func:`torch.set_warn_always` documentation for more details.
    """
    return _C._get_warnAlways()
################################################################################
# Define error checking functions
################################################################################

# These error checking functions must be kept consistent with their C++
# equivalents. Their C++ equivalents are mentioned where applicable.


def _check_with(
    error_type,
    cond: _Union[builtins.bool, SymBool],
    message: _Callable[[], str],
):  # noqa: F811
    if not isinstance(cond, (builtins.bool, SymBool)):
        raise TypeError(f"cond must be a bool, but got {type(cond)}")

    from torch.fx.experimental.symbolic_shapes import expect_true

    if expect_true(cond):
        return

    # error_type must be a subclass of Exception and not subclass of Warning
    assert issubclass(error_type, Exception) and not issubclass(error_type, Warning)

    if message is None:
        message_evaluated = (
            "Expected cond to be True, but got False. (Could this error "
            "message be improved? If so, please report an enhancement request "
            "to PyTorch.)"
        )
    else:
        if not callable(message):
            raise TypeError("message must be a callable")
        message_evaluated = str(message())

    raise error_type(message_evaluated)


def _check(cond, message=None):  # noqa: F811
    r"""Throws error containing an optional message if the specified condition
    is False.

    Error type: ``RuntimeError``

    C++ equivalent: ``TORCH_CHECK``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable that returns either a string or
            an object that has a ``__str__()`` method to be used as the error
            message. Default: ``None``
    """
    _check_with(RuntimeError, cond, message)


def _check_is_size(i, message=None, *, max=None):
    """Checks that a given integer is a valid size (i.e., is non-negative).
    You should use this over ``_check(i >= 0)`` because it can prevent
    ``GuardOnDataDependentSymNode`` exceptions by opting yourself into
    alternate semantics for ``guard_size_oblivious`` tests that treat values 0
    and 1 equivalently to all other values.

    When max is not None, this specifies an upper bound equivalent to
    ``_check(i <= max)``. This bound is also subject to alternate semantics:
    in ``guard_size_oblivious`` tests, we assume that a constant max bound is
    treated equivalently to all other values. Symbolic max bounds are not yet
    supported.

    NB: Do NOT use this in contexts where a -1 size would be valid (indicating
    to infer the size from context, or if you should wrap-around or truncate).
    Only use this if the only valid value is an honest to goodness size.
    """
    # This is responsible for the expect_true
    _check(i >= 0, message)
    from torch.fx.experimental.symbolic_shapes import _advise_is_size

    _advise_is_size(i)

    if max is not None:
        _check(i <= max, message)
        from torch.fx.experimental.symbolic_shapes import _advise_is_bounded

        _advise_is_bounded(i, max)


def _check_index(cond, message=None):  # noqa: F811
    r"""Throws error containing an optional message if the specified condition
    is False.

    Error type: ``IndexError``

    C++ equivalent: ``TORCH_CHECK_INDEX``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable that returns either a string or
            an object that has a ``__str__()`` method to be used as the error
            message. Default: ``None``
    """
    _check_with(IndexError, cond, message)


def _check_value(cond, message=None):  # noqa: F811
    r"""Throws error containing an optional message if the specified condition
    is False.

    Error type: ``ValueError``

    C++ equivalent: ``TORCH_CHECK_VALUE``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable that returns either a string or
            an object that has a ``__str__()`` method to be used as the error
            message. Default: ``None``
    """
    _check_with(ValueError, cond, message)


def _check_type(cond, message=None):  # noqa: F811
    r"""Throws error containing an optional message if the specified condition
    is False.

    Error type: ``TypeError``

    C++ equivalent: ``TORCH_CHECK_TYPE``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable that returns either a string or
            an object that has a ``__str__()`` method to be used as the error
            message. Default: ``None``
    """
    _check_with(TypeError, cond, message)


def _check_not_implemented(cond, message=None):  # noqa: F811
    r"""Throws error containing an optional message if the specified condition
    is False.

    Error type: ``NotImplementedError``

    C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable that returns either a string or
            an object that has a ``__str__()`` method to be used as the error
            message. Default: ``None``
    """
    _check_with(NotImplementedError, cond, message)


def _check_tensor_all_with(error_type, cond, message=None):  # noqa: F811
    if not is_tensor(cond):
        raise TypeError(f"cond must be a tensor, but got {type(cond)}")

    if not cond.dtype == torch.bool:
        raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")

    _check_with(error_type, cond._is_all_true().item(), message)  # type: ignore[arg-type]


# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
def _check_tensor_all(cond, message=None):  # noqa: F811
    r"""Throws error containing an optional message if the specified condition
    is False.

    Error type: ``RuntimeError``

    C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``

    Args:
        cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
            element is ``False``, throw error

        message (Callable, optional): Callable that returns either a string or
            an object that has a ``__str__()`` method to be used as the error
            message. Default: ``None``
    """
    _check_tensor_all_with(RuntimeError, cond, message)


################################################################################
# Define numeric constants
################################################################################

# For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
# NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
from math import e, inf, nan, pi

newaxis: None = None

__all__.extend(["e", "pi", "nan", "inf", "newaxis"])

################################################################################
# Define Storage and Tensor classes
################################################################################

from torch._tensor import Tensor  # usort: skip

# needs to be after torch.Tensor is defined to avoid circular dependencies
from torch import storage as storage  # usort: skip
from torch.storage import (
    _LegacyStorage,
    _StorageBase,
    _warn_typed_storage_removal,
    TypedStorage,
    UntypedStorage,
)

# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.
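
# Illustrative sketch (not part of the module), referring back to the _check*
# helpers defined above: the message argument is a callable so that expensive
# message formatting only happens on failure:
#
#     >>> def f(x):
#     ...     torch._check(x.dim() == 2, lambda: f"expected 2D, got {x.dim()}D")
#     ...     return x.t()
#     >>> f(torch.ones(3))
#     Traceback (most recent call last):
#         ...
#     RuntimeError: expected 2D, got 1D
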
_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
    UntypedStorage,
    DoubleStorage,
    FloatStorage,
    LongStorage,
    IntStorage,
    ShortStorage,
    CharStorage,
    ByteStorage,
    HalfStorage,
    BoolStorage,
    QUInt8Storage,
    QInt8Storage,
    QInt32Storage,
    BFloat16Storage,
    ComplexFloatStorage,
    ComplexDoubleStorage,
    QUInt4x2Storage,
    QUInt2x4Storage,
    TypedStorage,
}

# The _tensor_classes set is initialized by the call to initialize_python_bindings.
_tensor_classes: set[type["torch.Tensor"]] = set()

# If you edit these imports, please update torch/__init__.py.in as well
from torch import amp as amp, random as random, serialization as serialization
from torch._tensor_str import set_printoptions
from torch.amp import autocast, GradScaler
from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
from torch.serialization import load, save

################################################################################
# Initialize extension
################################################################################


# Shared memory manager needs to know the exact location of manager executable
def _manager_path():
    if _running_with_deploy() or platform.system() == "Windows":
        return b""
    path = get_file_path("torch", "bin", "torch_shm_manager")
    prepare_multiprocessing_environment(get_file_path("torch"))
    if not os.path.exists(path):
        raise RuntimeError("Unable to find torch_shm_manager at " + path)
    return path.encode("utf-8")


_C._initExtension(_manager_path())

del _manager_path

# Appease the type checker: it can't deal with direct setting of globals().
# Note that we will see "too many" functions when reexporting this way; there
# is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
# so that this import is good enough
if TYPE_CHECKING:
    # Some type signatures pulled in from _VariableFunctions here clash with
    # signatures already imported. For now these clashes are ignored; see
    # PR #43339 for details.
    from torch._C._VariableFunctions import *  # type: ignore[assignment, misc] # noqa: F403

    # Fixup segment_reduce visibility
    _segment_reduce = segment_reduce
    del segment_reduce  # noqa: F821

# Ops not to be exposed in `torch` namespace,
# mostly helper ops.
PRIVATE_OPS = ("unique_dim",)

__name, __obj = "", None
for __name in dir(_C._VariableFunctions):
    if __name.startswith("__") or __name in PRIVATE_OPS:
        continue
    __obj = getattr(_C._VariableFunctions, __name)
    __obj.__module__ = __name__  # "torch"
    # Hide some APIs that should not be public
    if __name == "segment_reduce":
        # TODO: Once the undocumented FC window is passed, remove the line below
        globals()[__name] = __obj
        __name = "_" + __name
    globals()[__name] = __obj
    if not __name.startswith("_"):
        __all__.append(__name)

del __name, __obj

################################################################################
# Add torch.dtype instances to the public API
################################################################################

import torch

__all__.extend(
    name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
)

################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependencies
################################################################################

# needs to be before from torch.functional import * to avoid circular dependencies
from torch._compile import _disable_dynamo  # usort: skip

################################################################################
# Import interface functions defined in Python
################################################################################

# needs to be after the above ATen bindings so we can overwrite from Python side
from torch import _VF as _VF, functional as functional  # usort: skip
from torch.functional import *  # usort: skip # noqa: F403

################################################################################
# Remove unnecessary members
################################################################################

del _StorageBase
del _LegacyStorage

################################################################################
# Define _assert
################################################################################


# needs to be before the submodule imports to avoid circular dependencies
def _assert(condition, message):
    r"""A wrapper around Python's assert which is symbolically traceable."""
    if type(condition) is not torch.Tensor and overrides.has_torch_function(
        (condition,)
    ):
        return overrides.handle_torch_function(
            _assert, (condition,), condition, message
        )
    assert condition, message
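
# Editor's note: a minimal usage sketch for _assert (illustrative only, not
# part of the upstream module). Unlike a plain `assert` statement, torch._assert
# dispatches through the __torch_function__ override machinery above, so
# Tensor-aware tracers can record the check instead of evaluating it eagerly.
def _assert_usage_example(x):
    torch._assert(x.numel() > 0, "expected a non-empty tensor")
    return x * 2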
################################################################################
# Import most common subpackages
################################################################################

# Use the redundant form so that type checkers know that these are a part of
# the public API. The "regular" import lines are there solely for the runtime
# side effect of adding to the imported module's members for other users.

# needs to be before import torch.nn as nn to avoid circular dependencies
from torch.autograd import (  # usort: skip
    enable_grad as enable_grad,
    inference_mode as inference_mode,
    no_grad as no_grad,
    set_grad_enabled as set_grad_enabled,
)

from torch import (
    __config__ as __config__,
    __future__ as __future__,
    _awaits as _awaits,
    accelerator as accelerator,
    autograd as autograd,
    backends as backends,
    cpu as cpu,
    cuda as cuda,
    distributed as distributed,
    distributions as distributions,
    fft as fft,
    futures as futures,
    hub as hub,
    jit as jit,
    linalg as linalg,
    mps as mps,
    mtia as mtia,
    multiprocessing as multiprocessing,
    nested as nested,
    nn as nn,
    optim as optim,
    overrides as overrides,
    profiler as profiler,
    sparse as sparse,
    special as special,
    testing as testing,
    types as types,
    utils as utils,
    xpu as xpu,
)
from torch.signal import windows as windows

# Quantized, sparse, AO, etc. should be last to get imported, as nothing
# is expected to depend on them.
from torch import ao as ao  # usort: skip

# nn.quant* depends on ao -- so should be after those.
import torch.nn.intrinsic
import torch.nn.qat
import torch.nn.quantizable
import torch.nn.quantized

_C._init_names(list(_storage_classes))

# attach docstrings to torch and tensor functions
from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs

del _torch_docs, _tensor_docs, _storage_docs, _size_docs
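
# Editor's note: the "redundant form" used above (`from torch import cuda as
# cuda`) is the PEP 484 convention for marking re-exports. A hedged sketch of
# the difference, with hypothetical module names:
#
#     from mypkg import helper            # implicit: strict type checkers
#                                         # (e.g. mypy --no-implicit-reexport)
#                                         # treat this as a private import
#     from mypkg import helper as helper  # explicit: declared part of the
#                                         # public API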
def compiled_with_cxx11_abi() -> builtins.bool:
    r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
    return _C._GLIBCXX_USE_CXX11_ABI
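
# Editor's note: a small usage sketch (illustrative, not upstream code). The
# C++11 ABI flag matters when linking custom C++ extensions against libtorch:
# the extension must be compiled with a matching -D_GLIBCXX_USE_CXX11_ABI value,
# or symbol lookups involving std::string will fail at link/load time.
def _cxx11_abi_flag_example() -> str:
    abi = 1 if compiled_with_cxx11_abi() else 0
    return f"-D_GLIBCXX_USE_CXX11_ABI={abi}"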
from torch import _library as _library, _ops as _ops

# Import the ops and classes "namespace"
from torch._ops import ops as ops  # usort: skip
from torch._classes import classes as classes  # usort: skip

sys.modules.setdefault(f"{__name__}.ops", ops)
sys.modules.setdefault(f"{__name__}.classes", classes)

# quantization depends on torch.fx and torch.ops
# Import quantization
from torch import quantization as quantization  # usort: skip

# Import the quasi random sampler
from torch import quasirandom as quasirandom  # usort: skip

# If you are seeing this, it means that this call site was not checked if
# the memory format could be preserved, and it was switched to old default
# behaviour of contiguous
legacy_contiguous_format = contiguous_format  # defined by _C._initExtension()

# Register fork handler to initialize OpenMP in child processes (see gh-28389)
from torch.multiprocessing._atfork import register_after_fork

register_after_fork(torch.get_num_threads)
del register_after_fork

# Import tools that require fully imported torch (for applying
# torch.jit.script as a decorator, for instance):
from torch._lobpcg import lobpcg as lobpcg

# These were previously defined in native_functions.yaml and appeared on the
# `torch` namespace, but we moved them to c10 dispatch to facilitate custom
# class usage. We add these lines here to preserve backward compatibility.
quantized_lstm = ops.aten.quantized_lstm
quantized_gru = ops.aten.quantized_gru

# Import experimental masked operations support. See
# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
# information.
from torch import masked as masked

# Import removed ops with error message about removal
from torch._linalg_utils import (  # type: ignore[misc]
    _symeig as symeig,
    eig,
    lstsq,
    matrix_rank,
    solve,
)

from torch.utils.dlpack import from_dlpack, to_dlpack


class _TorchCompileInductorWrapper:
    compiler_name = "inductor"

    def __init__(self, mode, options, dynamic):
        from torch._inductor.compiler_bisector import CompilerBisector

        self.config: dict[str, _Any] = {}
        self.dynamic = dynamic
        self.apply_mode(mode)
        self.apply_options(options)
        self.apply_options(CompilerBisector.get_config_change("inductor"))

        if self.config.get("triton.cudagraphs", False):
            os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
            # FIXME: CUDA Graph does not work well with CUPTI teardown.
            #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
            #   2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
            # Workaround: turn off CUPTI teardown when using CUDA Graphs.
            os.environ["TEARDOWN_CUPTI"] = "0"

    def __eq__(self, other):
        return (
            isinstance(other, _TorchCompileInductorWrapper)
            and self.config == other.config
            and self.dynamic == other.dynamic
        )

    def apply_mode(self, mode: _Optional[str]):
        if mode and mode != "default":
            from torch._inductor import list_mode_options

            self.apply_options(list_mode_options(mode, self.dynamic))

    def apply_options(self, options: _Optional[dict[str, _Any]]):
        if not options:
            return

        from torch._inductor import config

        current_config: dict[str, _Any] = config.get_config_copy()

        for key, val in options.items():
            attr_name = key.replace("-", "_")
            if attr_name not in current_config:
                raise RuntimeError(
                    f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
                )
            attr_type = config.get_type(attr_name)  # type: ignore[attr-defined]
            # Subscriptable generic types don't support isinstance so skip the type
            # check.
            # There doesn't seem to be a good way of checking membership without
            # 3rd party libraries.
            if _get_origin(attr_type) is None:
                if not isinstance(val, attr_type):
                    val_type_str = type(val).__name__
                    expected_type_str = type(current_config[attr_name]).__name__
                    raise RuntimeError(
                        f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
                    )
            self.config[attr_name] = val

    def __call__(self, model_, inputs_):
        from torch._inductor.compile_fx import compile_fx

        return compile_fx(model_, inputs_, config_patches=self.config)

    def get_compiler_config(self):
        from torch._inductor.compile_fx import get_patched_config_dict

        return get_patched_config_dict(config_patches=self.config)

    def reset(self):
        from torch._inductor import config

        if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
            if self.config.get("triton.cudagraphs", True):
                from torch._inductor.cudagraph_trees import reset_cudagraph_trees

                reset_cudagraph_trees()


class _TorchCompileWrapper:
    def __init__(self, backend, mode, options, dynamic):
        from torch._dynamo.backends.registry import lookup_backend

        if isinstance(backend, str):
            self.compiler_name = backend
        elif hasattr(backend, "__name__"):
            self.compiler_name = backend.__name__
        else:
            self.compiler_name = str(backend)
        self.dynamic = dynamic
        self.compiler_fn = lookup_backend(backend)
        self.kwargs = {}
        # only pass the args if they are non-empty
        if mode and mode != "default":
            self.kwargs["mode"] = mode
        if options:
            self.kwargs["options"] = options

    def __eq__(self, other):
        return (
            isinstance(other, _TorchCompileWrapper)
            and self.compiler_fn == other.compiler_fn
            and self.kwargs == other.kwargs
            and self.dynamic == other.dynamic
        )

    def __call__(self, model_, inputs_):
        return self.compiler_fn(model_, inputs_, **self.kwargs)

    def reset(self):
        if hasattr(self.compiler_fn, "reset"):
            self.compiler_fn.reset()


_InputT = _ParamSpec("_InputT")
_RetT = _TypeVar("_RetT")


@_overload
def compile(
    model: _Callable[_InputT, _RetT],
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...


@_overload
def compile(
    model: None = None,
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
def compile(
    model: _Optional[_Callable] = None,
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Union[
    _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
    _Callable[_InputT, _RetT],
]:
    """
    Optimizes given model/function using TorchDynamo and specified backend.
    If you are compiling an :class:`torch.nn.Module`, you can also use
    :meth:`torch.nn.Module.compile` to compile the module inplace
    without changing its structure.

    Concretely, for every frame executed within the compiled region, we will attempt
    to compile it and cache the compiled result on the code object for future
    use.  A single frame may be compiled multiple times if previous compiled
    results are not applicable for subsequent calls (this is called a "guard
    failure"); you can use TORCH_LOGS=guards to debug these situations.
    Multiple compiled results can be associated with a frame up to
    ``torch._dynamo.config.recompile_limit``, which defaults to 8, at which
    point we will fall back to eager.  Note that compile caches are per
    *code object*, not frame; if you dynamically create multiple copies of a
    function, they will all share the same code cache.

    Args:
       model (Callable): Module/function to optimize
       fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions
        in the function that it will optimize. If True, then we require that the entire function be
        capturable into a single graph. If this is not possible (that is, if there are graph breaks),
        then this will raise an error.
       dynamic (bool or None): Use dynamic shape tracing.  When this is True, we will up-front attempt
        to generate a kernel that is as dynamic as possible to avoid recompilations when
        sizes change.  This may not always work as some operations/optimizations will
        force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
        When this is False, we will NEVER generate dynamic kernels, we will always specialize.
        By default (None), we automatically detect if dynamism has occurred and compile a more
        dynamic kernel upon recompile.
       backend (str or Callable): backend to be used

        - "inductor" is the default backend, which is a good balance between performance and overhead

        - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`

        - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`

        - To register an out-of-tree custom backend:
          https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends
       mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"

        - "default" is the default mode, which is a good balance between performance and overhead

        - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
          useful for small batches.  Reduction of overhead can come at the cost of more memory
          usage, as we will cache the workspace memory required for the invocation so that we
          do not have to reallocate it on subsequent runs.  Reduction of overhead is not guaranteed
          to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs.
          There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints
          to debug.

        - "max-autotune" is a mode that leverages Triton or template based matrix multiplications
          on supported devices and Triton based convolutions on GPU.
          It enables CUDA graphs by default on GPU.

        - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without
          CUDA graphs

        - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`

       options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are

        - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set

        - `max_autotune` which will profile to pick the best matmul configuration

        - `fallback_random` which is useful when debugging accuracy issues

        - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores

        - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs

        - `trace.enabled` which is the most useful debugging flag to turn on

        - `trace.graph_diagram` which will show you a picture of your graph after fusion

        - For inductor you can see the full list of configs that it supports by calling
          `torch._inductor.list_options()`

       disable (bool): Turn torch.compile() into a no-op for testing

    Example::

        @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
        def foo(x):
            return torch.sin(x) + torch.cos(x)
    """
    import sysconfig

    _C._log_api_usage_once("torch.compile")
    if sys.version_info >= (3, 14):
        raise RuntimeError("torch.compile is not supported on Python 3.14+")
    elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1:
        raise RuntimeError(
            "torch.compile is not supported on Python built with GIL disabled"
        )

    # Decorator mode
    if model is None:

        def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
            if model is None:
                raise RuntimeError("Model can't be None")
            return compile(
                model,
                fullgraph=fullgraph,
                dynamic=dynamic,
                backend=backend,
                mode=mode,
                options=options,
                disable=disable,
            )

        return fn

    if mode is not None and options is not None:
        raise RuntimeError(
            "Either mode or options can be specified, but both can't be specified at the same time."
        )
    if mode is None and options is None:
        mode = "default"

    from torch._inductor.compiler_bisector import CompilerBisector

    if bisect_backend := CompilerBisector.get_backend():
        backend = bisect_backend

    if backend == "inductor":
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
        backend = _TorchCompileWrapper(backend, mode, options, dynamic)

    return torch._dynamo.optimize(
        backend=backend,
        nopython=fullgraph,
        dynamic=dynamic,
        disable=disable,
    )(model)  # type: ignore[return-value]
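
# Editor's note: a hedged usage sketch for the non-decorator form of compile()
# (illustrative, not upstream code). Passing a callable compiles it with the
# default "inductor" backend; compilation is triggered on the first call and
# the cached result is reused on later calls, subject to guard checks.
def _compile_usage_example():
    def f(x):
        return torch.sin(x) + torch.cos(x)

    compiled_f = compile(f, mode="reduce-overhead")
    return compiled_f(torch.randn(8))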
def _register_device_module(device_type, module):
    r"""Register an external runtime module of the specific :attr:`device_type`
    supported by torch.

    After the :attr:`module` is registered correctly, the user can refer
    to the external runtime module as part of torch with attribute torch.xxx.
    """
    # Make sure the device_type represents a supported device type for torch.
    device_type = torch.device(device_type).type
    m = sys.modules[__name__]
    if hasattr(m, device_type):
        raise RuntimeError(
            f"The runtime module of '{device_type}' has already "
            f"been registered with '{getattr(m, device_type)}'"
        )
    setattr(m, device_type, module)
    torch_module_name = ".".join([__name__, device_type])
    sys.modules[torch_module_name] = module


from torch import (
    export as export,
    func as func,
    library as library,
    return_types as return_types,
)
from torch._higher_order_ops import cond as cond, while_loop as while_loop
from torch.func import vmap as vmap


if not TYPE_CHECKING:
    from torch import _meta_registrations

# Enable CUDA Sanitizer
if "TORCH_CUDA_SANITIZER" in os.environ:
    import torch.cuda._sanitizer as csan

    csan.enable_cuda_sanitizer()

# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.sym_node
from torch import fx as fx


# Register MPS specific decomps
torch.backends.mps._init()

if not _running_with_deploy():
    from torch import compiler as compiler

    class _TritonLibrary:
        lib = torch.library.Library("triton", "DEF")
        ops_table: dict[tuple[str, str], _Callable] = {}

        @classmethod
        def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
            if (op_key, dispatch_key) not in cls.ops_table:
                cls.lib.define(full_schema)
                cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
                cls.ops_table[(op_key, dispatch_key)] = op_impl

            return cls.ops_table[(op_key, dispatch_key)]


# Deprecated attributes
_deprecated_attrs = {
    "has_mps": torch.backends.mps.is_built,
    "has_cuda": torch.backends.cuda.is_built,
    "has_cudnn": torch.backends.cudnn.is_available,
    "has_mkldnn": torch.backends.mkldnn.is_available,
}

if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    from torch import (
        _dynamo as _dynamo,
        _inductor as _inductor,
        _subclasses as _subclasses,
        onnx as onnx,
    )
else:
    _lazy_modules = {
        "_dynamo",
        "_inductor",
        "_export",
        # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
        "onnx",
    }

    def __getattr__(name):
        # Deprecated attrs
        replacement = _deprecated_attrs.get(name)
        if replacement is not None:
            import warnings

            warnings.warn(
                f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'",
                stacklevel=2,
            )
            return replacement()

        # Lazy modules
        if name in _lazy_modules:
            return importlib.import_module(f".{name}", __name__)

        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
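
# Editor's note: a hedged sketch of how an out-of-tree backend might use
# _register_device_module. "privateuseone" is the device type PyTorch reserves
# for third-party backends; `my_backend_module` is a hypothetical module that
# exposes the usual runtime APIs (is_available, device_count, ...). Kept as
# comments since the module does not exist here:
#
#     import my_backend_module
#     torch._register_device_module("privateuseone", my_backend_module)
#     torch.privateuseone.is_available()  # now resolvable as torch.<device_type>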
def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
    """
    Returns the module associated with a given device (e.g., torch.device('cuda'),
    "mtia:0", "xpu", ...). If no device is given, return the module for the current
    accelerator or CPU if none is present.
    """
    if isinstance(device, torch.device):
        device_module_name = device.type
    elif isinstance(device, str):
        device_module_name = torch.device(device).type
    elif device is None:
        # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
        device_module_name = torch._C._get_accelerator().type
    else:
        raise RuntimeError(
            f"Invalid value of device '{device}', expect torch.device, str, or None"
        )
    device_module = getattr(torch, device_module_name, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
        )
    return device_module
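
# Editor's note: illustrative calls for get_device_module (not upstream code).
# Each form resolves to the torch submodule for that device type.
def _get_device_module_example():
    cpu_mod = get_device_module("cpu")  # -> torch.cpu
    default_mod = get_device_module()  # current accelerator, or torch.cpu
    return cpu_mod, default_mod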
def _constrain_as_size(
    symbol,
    min: _Optional[builtins.int] = None,
    max: _Optional[builtins.int] = None,
):
    """
    This indicates that a given int is size-like, and can be used in any context where a size is expected.
    You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
    which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
    GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.

    This function has unusual semantics: in some circumstances in framework
    code, we will treat this int as >= 2 (when we do a size-oblivious guard).
    This makes it easier to use the unbacked int in size contexts,
    as we will often attempt to guard on a size being zero/one
    (e.g., when computing the contiguity of a tensor, or testing if
    broadcasting can occur), which will not work on unbacked SymInts.
    However, if we conservatively assume that the size is not zero/one, we will
    end up with a graph that will still work even if the size is zero/one.

    For more details, see
    https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
    """
    torch.sym_constrain_range_for_size(symbol, min=min, max=max)


from torch import _logging

_logging._init_logs()


def _import_device_backends():
    """
    Leverage the Python plugin mechanism to load out-of-tree device extensions.
    See this RFC: https://github.com/pytorch/pytorch/issues/122468
    """
    from importlib.metadata import entry_points

    group_name = "torch.backends"
    if sys.version_info < (3, 10):
        backend_extensions = entry_points().get(group_name, ())
    else:
        backend_extensions = entry_points(group=group_name)

    for backend_extension in backend_extensions:
        try:
            # Load the extension
            entrypoint = backend_extension.load()
            # Call the entrypoint
            entrypoint()
        except Exception as err:
            raise RuntimeError(
                f"Failed to load the backend extension: {backend_extension.name}. "
                f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
            ) from err


def _is_device_backend_autoload_enabled() -> builtins.bool:
    """
    Whether autoloading out-of-tree device extensions is enabled.
    The switch depends on the value of the environment variable
    `TORCH_DEVICE_BACKEND_AUTOLOAD`.

    Returns:
        bool: Whether to enable autoloading the extensions. Enabled by default.

    Examples:
        >>> torch._is_device_backend_autoload_enabled()
        True
    """
    # enabled by default
    return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1"


def _as_tensor_fullprec(t):
    """
    Like torch.as_tensor, but when given Python data types it will keep
    them in full precision.  Used for calling convention for Dynamo.
    """
    ty = type(t)
    if ty is builtins.float:
        return torch.as_tensor(t, dtype=torch.float64)
    elif ty is builtins.int:
        return torch.as_tensor(t, dtype=torch.int64)
    else:
        return torch.as_tensor(t)


# `_import_device_backends` should be kept at the end to ensure
# all the other functions in this module that may be accessed by
# an autoloaded backend are defined
if _is_device_backend_autoload_enabled():
    _import_device_backends()
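
# Editor's note: a small illustration of _as_tensor_fullprec's contract
# (illustrative, not upstream code). A plain torch.as_tensor(1.5) would produce
# the default dtype (usually float32); the full-precision variant pins Python
# floats/ints to float64/int64 so Dynamo's calling convention loses no precision.
def _as_tensor_fullprec_example():
    assert _as_tensor_fullprec(1.5).dtype == torch.float64
    assert _as_tensor_fullprec(2).dtype == torch.int64
    return True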