# mypy: allow-untyped-defs
"""Defines bias subclasses that work with scaled_dot_product_attention"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn

import torch
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    is_flash_attention_available,
    SDPAParams,
)
from torch.nn.attention import _raise_kernel_warnings
from torch.nn.attention._utils import (
    _calculate_scale,
    _input_requires_grad,
    _postprocess_flash_output,
    _validate_sdpa_input,
)


__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]


torch._dynamo.allow_in_graph(is_flash_attention_available)
torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)
class CausalVariant(IntEnum):
    r"""
    Enum for causal variants used in attention mechanisms.

    Defines two types of causal biases:

    `UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention.
    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]


    `LOWER_RIGHT`: Represents lower-right triangular bias; the included values are aligned to the
    lower-right corner of the matrix.
    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
    tensors are equal, since the triangular matrix is square.

    .. warning:: This enum is a prototype and subject to change.
    """

    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()
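# A minimal sketch (not part of the module; the helper name is hypothetical):
# materializes both variants for a rectangular (3, 4) shape to show how the
# UPPER_LEFT and LOWER_RIGHT masks described above differ. For square shapes
# the two variants coincide.
def _demo_causal_variants(seq_len_q: int = 3, seq_len_kv: int = 4) -> None:
    upper_left = torch.tril(torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool))
    lower_right = torch.tril(
        torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool),
        diagonal=seq_len_kv - seq_len_q,
    )
    # UPPER_LEFT keeps the ones anchored to the top-left corner; LOWER_RIGHT
    # shifts the diagonal so the last query position attends to every key position.
    print(upper_left)
    print(lower_right)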
class CausalBias(torch.Tensor):
    """
    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.

    This class is used for defining causal (triangular) attention biases. For constructing the bias, there exist
    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.

    Example:

    .. code-block:: python

        from torch.nn.attention.bias import causal_lower_right

        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

        # Create a lower-right causal bias
        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)

    .. warning:: This class is a prototype and subject to change.
    """

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """
        Initializes the CausalBias instance with a specified variant and sequence lengths.

        Args:
            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
            seq_len_q (int): The sequence length of the query tensor.
            seq_len_kv (int): The sequence length of the key/value tensor.

        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Materializes the causal bias into a tensor form.

        Depending on the variant, this method generates either an upper-left or lower-right
        triangular matrix to represent the causal bias.

        Args:
            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.

        Returns:
            torch.Tensor: The materialized bias tensor.
        """
        if device is None:
            device = torch.device("cpu")
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
    ) -> torch.Tensor:
        r"""
        Handles the logic for computing attention with the specified causal bias.

        Args:
            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
            attn_mask (CausalBias): The type of causal attention to apply.
                A boolean mask where a value of True indicates that the element *should* take part in attention.
                A float mask of the same type as query, key, value that is added to the attention score.
            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
            is_causal (bool): If true, assumes upper-left causal attention masking and errors if both attn_mask and is_causal are set.
            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                to :math:`\frac{1}{\sqrt{E}}`.
            enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled; by default it is set to False.

        Returns:
            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

        Raises:
            ValueError: If the causal bias variant is not a CausalVariant type.
        """
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")

        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            return F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
                enable_gqa=enable_gqa,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(
                query, key, value, None, dropout_p, is_causal, enable_gqa
            )
            if can_use_flash_attention(sdpa_params):
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # We can't use efficient attention; the only support for lower right is via materialization
                return F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                    enable_gqa=enable_gqa,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
        if kwargs is None:
            kwargs = {}
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self):  # type: ignore[override]
        return self._materialize().__repr__()
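# A minimal sketch (hypothetical helper, not part of the module): passing a
# CausalBias as attn_mask routes scaled_dot_product_attention through
# CausalBias.__torch_function__ and _dispatch. Assuming the CUDA-only flash and
# efficient kernels report unavailable for CPU tensors, the LOWER_RIGHT path
# falls back to materializing the boolean mask, so the result should match
# calling SDPA with the materialized mask directly.
def _demo_causal_bias_dispatch() -> None:
    bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 2, 4, 3, 5, 8
    q = torch.randn(bsz, num_heads, seqlen_q, head_dim)
    k = torch.randn(bsz, num_heads, seqlen_kv, head_dim)
    v = torch.randn(bsz, num_heads, seqlen_kv, head_dim)
    bias = CausalBias(CausalVariant.LOWER_RIGHT, seqlen_q, seqlen_kv)
    out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    out_mask = F.scaled_dot_product_attention(
        q, k, v, attn_mask=bias._materialize(q.device)
    )
    torch.testing.assert_close(out_bias, out_mask)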
def causal_upper_left(*size) -> CausalBias:
    """
    Creates an upper-left triangular causal bias.

    This function generates an upper-left triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the upper-left corner of the matrix.
    This is equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.

    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The UPPER_LEFT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)
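# A minimal sketch (hypothetical helper, not part of the module): for equal
# query/key lengths, an UPPER_LEFT bias takes the first branch in
# CausalBias._dispatch and is equivalent to calling scaled_dot_product_attention
# with is_causal=True, as the docstring above states.
def _demo_causal_upper_left() -> None:
    q = torch.randn(2, 4, 6, 8)
    k = torch.randn(2, 4, 6, 8)
    v = torch.randn(2, 4, 6, 8)
    attn_bias = causal_upper_left(6, 6)
    out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(out_bias, out_causal)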
def causal_lower_right(*size) -> CausalBias:
    """
    Creates a lower-right triangular causal bias.

    This function generates a lower-right triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the lower-right corner of the matrix.

    The equivalent PyTorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The LOWER_RIGHT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)
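# A minimal sketch (hypothetical helper, not part of the module): a lower-right
# bias for seq_len_q=3, seq_len_kv=4 materializes to the boolean mask shown in
# the docstring above, with the ones aligned to the lower-right corner.
def _demo_causal_lower_right() -> None:
    attn_bias = causal_lower_right(3, 4)
    expected = torch.tensor(
        [
            [True, True, False, False],
            [True, True, True, False],
            [True, True, True, True],
        ]
    )
    assert torch.equal(attn_bias._materialize(), expected)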