# mypy: allow-untyped-defs
"""This module contains functions and classes that alter the behavior of
torch.nn.functional.scaled_dot_product_attention"""
import contextlib
from collections.abc import Iterable
from typing import Union
from warnings import warn

import torch.backends.cuda
from torch._C import _SDPBackend as SDPBackend
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    SDPAParams,
)


__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]

# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only effects users of bias subclasses
# If this is set to True, we will warn the user if they are not using the fused kernels
# As well, it will raise warnings for all the reasons why the fused kernels can't be run.
# To set this to True, run
# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
WARN_FOR_UNFUSED_KERNELS = False


# Hacks for Sphinx documentation:
# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
SDPBackend = SDPBackend
r"""An enum-like class that contains the different backends for scaled dot product attention.
This backend class is designed to be used with the sdpa_kernel context manager.

The following Enums are available:
    - ERROR: An error occurred when trying to determine the backend.
    - MATH: The math backend for scaled dot product attention.
    - FLASH_ATTENTION: The flash attention backend for scaled dot product attention.
    - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention.
    - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention.

See :func:`torch.nn.attention.sdpa_kernel` for more details.

.. warning:: This class is in beta and subject to change.
"""
SDPBackend.__module__ = __name__
SDPBackend.__name__ = "SDPBackend"


def _raise_kernel_warnings(params: SDPAParams) -> None:
    """
    If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings
    for all the reasons why the fused kernels can't be run.
    If using subclasses
    """
    if WARN_FOR_UNFUSED_KERNELS:
        # The second call with debug=True emits one warning per unmet requirement.
        if not can_use_efficient_attention(params):
            warn("Efficient attention can't be used because:")
            can_use_efficient_attention(params, True)
        if not can_use_flash_attention(params):
            warn("Flash attention can't be used because:")
            can_use_flash_attention(params, True)


# Maps the torch.backends.cuda flag prefix to the SDPBackend member name.
_backend_names = {
    "cudnn": "CUDNN_ATTENTION",
    "flash": "FLASH_ATTENTION",
    "mem_efficient": "EFFICIENT_ATTENTION",
    "math": "MATH",
}


def _backend_from_string(name: str):
    """Map a member name such as ``"MATH"`` to the corresponding ``SDPBackend``."""
    return getattr(SDPBackend, name)


def _cur_sdpa_kernel_backends(with_priority: bool = False):
    """Return the list of currently enabled SDPA backends.

    If ``with_priority`` is True, the list is ordered according to the
    current global SDP priority order.
    """
    enabled_backends = [
        getattr(SDPBackend, enum_name)
        for flag_name, enum_name in _backend_names.items()
        if getattr(torch.backends.cuda, f"{flag_name}_sdp_enabled")()
    ]
    if with_priority:
        priority_order = torch._C._get_sdp_priority_order()
        enabled_backends.sort(key=lambda backend: priority_order.index(int(backend)))
    return enabled_backends


def _sdpa_kernel(backends: Iterable, set_priority: bool = False):
    """Enable exactly the given backends; optionally install them as the priority order."""
    for flag_name, enum_name in _backend_names.items():
        is_enabled = getattr(SDPBackend, enum_name) in backends
        getattr(torch.backends.cuda, f"enable_{flag_name}_sdp")(is_enabled)
    if set_priority:
        # backends should be a unique list
        priority = [int(backend) for backend in backends]
        # Append any backend not mentioned by the user, keeping its prior rank.
        for existing in torch._C._get_sdp_priority_order():
            if existing not in priority:
                priority.append(int(existing))
        torch._C._set_sdp_priority_order(priority)
@contextlib.contextmanager
def sdpa_kernel(
    backends: Union[list[SDPBackend], SDPBackend],
    set_priority: bool = False,
):
    r"""
    Context manager to select which backend to use for scaled dot product attention.

    .. warning:: This function is beta and subject to change.

    Args:
        backends (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.
        set_priority (bool=False): Whether the ordering of the backends is interpreted as their priority order.

    Example:

    .. code-block:: python

        from torch.nn.functional import scaled_dot_product_attention
        from torch.nn.attention import SDPBackend, sdpa_kernel

        # Only enable flash attention backend
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            scaled_dot_product_attention(...)

        # Enable the Math or Efficient attention backends
        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
            scaled_dot_product_attention(...)

    This context manager can be used to select which backend to use for scaled dot product
    attention. Upon exiting the context manager, the previous state of the flags will be restored,
    enabling all backends.
    """
    assert isinstance(backends, (list, SDPBackend)), (
        "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
    )

    if isinstance(backends, SDPBackend):
        backends = [backends]

    # De-duplicate while preserving order (order matters when set_priority=True).
    backends = list(dict.fromkeys(backends))
    previous_backends = _cur_sdpa_kernel_backends(with_priority=set_priority)
    try:
        _sdpa_kernel(backends, set_priority)
        yield {}
    finally:
        # Restore whatever was enabled (and, if requested, prioritized) before entry.
        _sdpa_kernel(previous_backends, set_priority)
# variadic version of sdpa_kernel for dynamo to use while reconstructing@contextlib.contextmanagerdef_sdpa_kernel_variadic(*backends:SDPBackend):withsdpa_kernel(list(backends)):yielddef_get_flash_version()->str:"""This returns the closest matching tag for the flash attention backend"""return"2.5.7"
# --- Documentation-site scrape residue below (not part of the module source) ---
# Docs
# Access comprehensive developer documentation for PyTorch
# To analyze traffic and optimize your experience, we serve cookies on this site.
# By clicking or navigating, you agree to allow our usage of cookies. As the
# current maintainers of this site, Facebook's Cookies Policy applies. Learn
# more, including about available controls: Cookies Policy.