[docs]defis_built():r""" Return whether PyTorch is built with CUDA support. Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch binary were run on a machine with working CUDA drivers and devices, we would be able to use it. """returntorch._C._has_cuda
class cuFFTPlanCacheAttrContextProp:
    r"""
    Descriptor exposing one attribute of a per-device cuFFT plan cache.

    Like a regular ``ContextProp``, but uses the ``.device_index`` attribute of
    the calling object as the first argument to the getter and setter.

    Args:
        getter: Called as ``getter(device_index)`` to read the value.
        setter: Called as ``setter(device_index, value)`` to write the value.
            If it is a ``str`` instead, the attribute is read-only and
            assigning to it raises ``RuntimeError`` with that string as the
            message.
    """

    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype=None):
        # Class-level access (e.g. ``cuFFTPlanCache.size`` from ``help()`` or
        # doc tooling) has no instance and therefore no device to query.
        # Return the descriptor itself instead of failing on
        # ``None.device_index``.
        if obj is None:
            return self
        return self.getter(obj.device_index)

    def __set__(self, obj, val):
        if isinstance(self.setter, str):
            # A string "setter" marks the attribute read-only; it is the message.
            raise RuntimeError(self.setter)
        self.setter(obj.device_index, val)


class cuFFTPlanCache:
    r"""
    Represent a specific plan cache for a specific `device_index`.

    The attributes `size` and `max_size`, and method `clear`, can fetch and/or
    change properties of the C++ cuFFT plan cache.
    """

    def __init__(self, device_index):
        self.device_index = device_index

    size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_size,
        ".size is a read-only property showing the number of plans currently in the "
        "cache. To change the cache capacity, set cufft_plan_cache.max_size.",
    )

    max_size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size
    )

    def clear(self):
        return torch._cufft_clear_plan_cache(self.device_index)


class cuFFTPlanCacheManager:
    r"""
    Represent all cuFFT plan caches, returning the :class:`cuFFTPlanCache` for a
    given device when indexed.

    When this object is used directly as a ``cuFFTPlanCache`` object (e.g., by
    setting its ``.max_size`` attribute), the current device's cuFFT plan cache
    is used.
    """

    # Class-level default, name-mangled to ``_cuFFTPlanCacheManager__initialized``.
    # While it reads False, ``__setattr__`` stores attributes on the manager
    # itself. Once ``__init__`` sets it, assignments are forwarded to the
    # current device's cache.
    __initialized = False

    def __init__(self):
        self.caches = []
        self.__initialized = True

    def __getitem__(self, device):
        index = torch.cuda._utils._get_device_index(device)
        if index < 0 or index >= torch.cuda.device_count():
            raise RuntimeError(
                f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got "
                f"device with index {index}"
            )
        if len(self.caches) == 0:
            # Lazily build one cache wrapper per visible device on first access.
            self.caches.extend(
                cuFFTPlanCache(i) for i in range(torch.cuda.device_count())
            )
        return self.caches[index]

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Never forward
        # ``caches`` or dunder names. On an instance created without
        # ``__init__`` (e.g. by ``copy`` or unpickling), forwarding them would
        # call ``self[...]``, which reads ``self.caches`` and re-enters this
        # method without end.
        if name == "caches" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        return getattr(self[torch.cuda.current_device()], name)

    def __setattr__(self, name, value):
        if self.__initialized:
            return setattr(self[torch.cuda.current_device()], name, value)
        else:
            return super().__setattr__(name, value)


class cuBLASModule:
    r"""
    Attribute-style access to global cuBLAS flags.

    Reading or assigning ``allow_tf32``,
    ``allow_fp16_reduced_precision_reduction`` or
    ``allow_bf16_reduced_precision_reduction`` calls the matching ``torch._C``
    getter/setter. Any other name raises ``AttributeError``.
    """

    def __getattr__(self, name):
        if name == "allow_tf32":
            return torch._C._get_cublas_allow_tf32()
        elif name == "allow_fp16_reduced_precision_reduction":
            return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
        elif name == "allow_bf16_reduced_precision_reduction":
            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
        raise AttributeError("Unknown attribute " + name)

    def __setattr__(self, name, value):
        if name == "allow_tf32":
            return torch._C._set_cublas_allow_tf32(value)
        elif name == "allow_fp16_reduced_precision_reduction":
            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
        elif name == "allow_bf16_reduced_precision_reduction":
            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
        raise AttributeError("Unknown attribute " + name)


# User-facing string names accepted by preferred_linalg_library(), mapped to
# the corresponding backend enum values.
_LinalgBackends = {
    "default": torch._C._LinalgBackend.Default,
    "cusolver": torch._C._LinalgBackend.Cusolver,
    "magma": torch._C._LinalgBackend.Magma,
}
_LinalgBackends_str = ", ".join(_LinalgBackends.keys())
def preferred_linalg_library(
    backend: Union[None, str, torch._C._LinalgBackend] = None
) -> torch._C._LinalgBackend:
    r"""
    Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations.

    .. warning:: This flag is experimental and subject to change.

    When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
    and if both are available it decides which to use with a heuristic.
    This flag (a :class:`str`) allows overriding those heuristics.

    * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
    * If `"magma"` is set then MAGMA will be used wherever possible.
    * If `"default"` (the default) is set then heuristics will be used to pick between
      cuSOLVER and MAGMA if both are available.
    * When no input is given, this function returns the currently preferred library.
    * User may use the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER
      globally.
      This flag only sets the initial value of the preferred library and the preferred library
      may still be overridden by this function call later in your script.

    Note: When a library is preferred other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
    for your application's inputs.

    Args:
        backend: ``None`` to only query the current preference; one of the
            strings listed above; or a :class:`torch._C._LinalgBackend` value.

    Returns:
        The preferred linalg backend in effect after any update.

    Raises:
        RuntimeError: If ``backend`` is an unrecognized string or has an
            unsupported type.

    Currently supported linalg operators:

    * :func:`torch.linalg.inv`
    * :func:`torch.linalg.inv_ex`
    * :func:`torch.linalg.cholesky`
    * :func:`torch.linalg.cholesky_ex`
    * :func:`torch.cholesky_solve`
    * :func:`torch.cholesky_inverse`
    * :func:`torch.linalg.lu_factor`
    * :func:`torch.linalg.lu`
    * :func:`torch.linalg.lu_solve`
    * :func:`torch.linalg.qr`
    * :func:`torch.linalg.eigh`
    * :func:`torch.linalg.eigvalsh`
    * :func:`torch.linalg.svd`
    * :func:`torch.linalg.svdvals`
    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _LinalgBackends:
            raise RuntimeError(
                "Unknown input value. " f"Choose from: {_LinalgBackends_str}."
            )
        torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
    elif isinstance(backend, torch._C._LinalgBackend):
        torch._C._set_linalg_preferred_backend(backend)
    else:
        raise RuntimeError("Unknown input value type.")

    # Always report the effective preference, whether or not it was just changed.
    return torch._C._get_linalg_preferred_backend()
[docs]defpreferred_blas_library(backend:Union[None,str,torch._C._BlasBackend]=None)->torch._C._BlasBackend:r""" Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only]. .. warning:: This flag is experimental and subject to change. When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available. For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance. This flag (a :class:`str`) allows overriding which BLAS library to use. * If `"cublas"` is set then cuBLAS will be used wherever possible. * If `"cublaslt"` is set then cuBLASLt will be used wherever possible. * If `"ck"` is set then CK will be used wherever possible. * When no input is given, this function returns the currently preferred library. * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt globally. This flag only sets the initial value of the preferred library and the preferred library may still be overridden by this function call later in your script. Note: When a library is preferred other libraries may still be used if the preferred library doesn't implement the operation(s) called. This flag may achieve better performance if PyTorch's library selection is incorrect for your application's inputs. """ifbackendisNone:passelifisinstance(backend,str):ifbackendnotin_BlasBackends:raiseRuntimeError("Unknown input value. "f"Choose from: {_BlasBackends_str}.")torch._C._set_blas_preferred_backend(_BlasBackends[backend])elifisinstance(backend,torch._C._BlasBackend):torch._C._set_blas_preferred_backend(backend)else:raiseRuntimeError("Unknown input value type.")returntorch._C._get_blas_preferred_backend()
fromtorch._Cimport_SDPAParamsasSDPAParams,_SDPBackendasSDPBackend# Set the __module__ attributeSDPAParams.__module__="torch.backends.cuda"SDPAParams.__name__="SDPAParams"
def flash_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Report whether the flash scaled dot product attention backend is currently enabled.
    """
    is_enabled = torch._C._get_flash_sdp_enabled()
    return is_enabled
def enable_flash_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Turn the flash scaled dot product attention backend on (``True``) or off (``False``).
    """
    set_flag = torch._C._set_sdp_use_flash
    set_flag(enabled)
def mem_efficient_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Report whether the memory-efficient scaled dot product attention backend is currently enabled.
    """
    is_enabled = torch._C._get_mem_efficient_sdp_enabled()
    return is_enabled
def enable_mem_efficient_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Turn the memory-efficient scaled dot product attention backend on (``True``) or off (``False``).
    """
    set_flag = torch._C._set_sdp_use_mem_efficient
    set_flag(enabled)
def math_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Report whether the math (reference) scaled dot product attention backend is currently enabled.
    """
    is_enabled = torch._C._get_math_sdp_enabled()
    return is_enabled
def enable_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Turn the math scaled dot product attention backend on (``True``) or off (``False``).
    """
    set_flag = torch._C._set_sdp_use_math
    set_flag(enabled)
[docs]defallow_fp16_bf16_reduction_math_sdp(enabled:bool):r""" .. warning:: This flag is beta and subject to change. Enables or disables fp16/bf16 reduction in math scaled dot product attention. """torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)
[docs]deffp16_bf16_reduction_math_sdp_allowed():r""" .. warning:: This flag is beta and subject to change. Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not. """returntorch._C._get_math_sdp_allow_fp16_bf16_reduction()
[docs]defis_flash_attention_available()->bool:r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention. Returns: True if FlashAttention is built and available; otherwise, False. Note: This function is dependent on a CUDA-enabled build of PyTorch. It will return False in non-CUDA environments. """returntorch._C._is_flash_attention_available()
def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Decide whether FlashAttention can serve a given scaled_dot_product_attention call.

    Args:
        params: An ``SDPAParams`` instance holding the query, key and value
            tensors, an optional attention mask, the dropout rate, and a flag
            indicating whether the attention is causal.
        debug: If ``True``, warn with the reasons FlashAttention could not be
            run. Defaults to ``False``.

    Returns:
        ``True`` if FlashAttention can be used with these parameters, otherwise ``False``.

    Note:
        This depends on a CUDA-enabled PyTorch build; in non-CUDA environments
        the result is ``False``.
    """
    verdict = torch._C._can_use_flash_attention(params, debug)
    return verdict
def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Decide whether memory-efficient attention can serve a given scaled_dot_product_attention call.

    Args:
        params: An ``SDPAParams`` instance holding the query, key and value
            tensors, an optional attention mask, the dropout rate, and a flag
            indicating whether the attention is causal.
        debug: If ``True``, warn with the reasons efficient_attention could not
            be run. Defaults to ``False``.

    Returns:
        ``True`` if efficient_attention can be used with these parameters, otherwise ``False``.

    Note:
        This depends on a CUDA-enabled PyTorch build; in non-CUDA environments
        the result is ``False``.
    """
    verdict = torch._C._can_use_mem_efficient_attention(params, debug)
    return verdict
def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Decide whether cuDNN attention can serve a given scaled_dot_product_attention call.

    Args:
        params: An ``SDPAParams`` instance holding the query, key and value
            tensors, an optional attention mask, the dropout rate, and a flag
            indicating whether the attention is causal.
        debug: If ``True``, warn with the reasons cuDNN attention could not be
            run. Defaults to ``False``.

    Returns:
        ``True`` if cuDNN attention can be used with these parameters, otherwise ``False``.

    Note:
        This depends on a CUDA-enabled PyTorch build; in non-CUDA environments
        the result is ``False``.
    """
    verdict = torch._C._can_use_cudnn_attention(params, debug)
    return verdict
[docs]defcudnn_sdp_enabled():r""" .. warning:: This flag is beta and subject to change. Returns whether cuDNN scaled dot product attention is enabled or not. """returntorch._C._get_cudnn_sdp_enabled()
[docs]defenable_cudnn_sdp(enabled:bool):r""" .. warning:: This flag is beta and subject to change. Enables or disables cuDNN scaled dot product attention. """torch._C._set_sdp_use_cudnn(enabled)
@contextlib.contextmanager
@deprecated(
    (
        "`torch.backends.cuda.sdp_kernel()` is deprecated. "
        "In the future, this context manager will be removed. "
        "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, "
        "with updated signature."
    ),
    category=FutureWarning,
)
def sdp_kernel(
    enable_flash: bool = True,
    enable_math: bool = True,
    enable_mem_efficient: bool = True,
    enable_cudnn: bool = True,
):
    r"""
    .. warning:: This flag is beta and subject to change.

    This context manager can be used to temporarily enable or disable any of
    the four backends for scaled dot product attention: flash, memory-efficient,
    math, and cuDNN. Upon exiting the context manager, the previous state of the
    flags will be restored.

    Deprecated: use :func:`torch.nn.attention.sdpa_kernel` instead. This wrapper
    translates its boolean flags into the backend list that function expects.

    Args:
        enable_flash: Allow the flash attention backend.
        enable_math: Allow the math backend.
        enable_mem_efficient: Allow the memory-efficient backend.
        enable_cudnn: Allow the cuDNN backend.

    Yields:
        Whatever ``torch.nn.attention.sdpa_kernel`` yields on entry.
    """
    # Imported lazily, at call time.
    from torch.nn.attention import sdpa_kernel

    backend_list = []
    if enable_flash:
        backend_list.append(SDPBackend.FLASH_ATTENTION)
    if enable_mem_efficient:
        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
    if enable_math:
        backend_list.append(SDPBackend.MATH)
    if enable_cudnn:
        backend_list.append(SDPBackend.CUDNN_ATTENTION)

    # sdpa_kernel's own __exit__ runs when this generator resumes or is closed,
    # so no extra try/finally is needed here.
    with sdpa_kernel(backend_list) as context:
        yield context
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.