class Threshold(Module):
    r"""Thresholds each element of the input Tensor.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['threshold', 'value', 'inplace']

    threshold: float
    value: float
    inplace: bool

    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
        super().__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace
        # TODO: check in THNN (if inplace == True, then assert value <= threshold)

    def forward(self, input: Tensor) -> Tensor:
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'threshold={self.threshold}, value={self.value}{inplace_str}'
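# Illustrative sketch (not part of the module above): checks the Threshold formula
# y = x if x > threshold else value against a manual torch.where computation.
# The threshold, replacement value and input entries below are arbitrary example values.
import torch
import torch.nn as nn

m = nn.Threshold(0.1, 20.0)                       # threshold=0.1, value=20.0
x = torch.tensor([-1.0, 0.05, 0.1, 0.5])
expected = torch.where(x > 0.1, x, torch.tensor(20.0))
assert torch.equal(m(x), expected)                # entries <= 0.1 are replaced by 20.0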
class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)

      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input), m(-input)))
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise,
    as described in the paper:

    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from the uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})` during training, while during
    evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`.

    See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/RReLU.png

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """

    __constants__ = ['lower', 'upper', 'inplace']

    lower: float
    upper: float
    inplace: bool

    def __init__(self, lower: float = 1. / 8, upper: float = 1. / 3, inplace: bool = False):
        super().__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'lower={self.lower}, upper={self.upper}{inplace_str}'
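# Sketch: in training mode RReLU samples a fresh negative slope per element, while in
# eval mode it uses the fixed slope (lower + upper) / 2. The bounds and input values
# below are arbitrary; only the eval-mode output is deterministic and checked.
import torch
import torch.nn as nn

m = nn.RReLU(lower=0.1, upper=0.3)
x = torch.tensor([-1.0, -2.0, 3.0])

m.train()
y_train = m(x)                                    # negatives scaled by a random a ~ U(0.1, 0.3)

m.eval()
y_eval = m(x)                                     # negatives scaled by (0.1 + 0.3) / 2 = 0.2
assert torch.allclose(y_eval, torch.tensor([-0.2, -0.4, 3.0]))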
class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) =
        \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['min_val', 'max_val', 'inplace']

    min_val: float
    max_val: float
    inplace: bool

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
        inplace: bool = False,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None
    ) -> None:
        super().__init__()
        if min_value is not None:
            warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
            min_val = min_value
        if max_value is not None:
            warnings.warn("keyword argument max_value is deprecated and renamed to max_val")
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        assert self.max_val > self.min_val

    def forward(self, input: Tensor) -> Tensor:
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'min_val={self.min_val}, max_val={self.max_val}{inplace_str}'
class ReLU6(Hardtanh):
    r"""Applies the element-wise function:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super().__init__(0., 6., inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class Sigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.sigmoid(input)
class Hardsigmoid(Module):
    r"""Applies the Hardsigmoid function element-wise.

    Hardsigmoid is defined as:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardsigmoid.png

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardsigmoid(input, self.inplace)
class Tanh(Module):
    r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.

    Tanh is defined as:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.tanh(input)
class SiLU(Module):
    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.

    The SiLU function is also known as the swish function.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SiLU.png

    Examples::

        >>> m = nn.SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
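# Sketch: SiLU is x * sigmoid(x), so the module output should match the manual product.
# Purely illustrative, with a random input.
import torch
import torch.nn as nn

x = torch.randn(5)
assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))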
class Mish(Module):
    r"""Applies the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Mish.png

    Examples::

        >>> m = nn.Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.mish(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class Hardswish(Module):
    r"""Applies the Hardswish function, element-wise, as described in the paper:
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    Hardswish is defined as:

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardswish.png

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardswish(input, self.inplace)
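# Sketch: Hardswish is the input gated by Hardsigmoid, i.e. Hardswish(x) = x * Hardsigmoid(x),
# since both use the same piecewise-linear gate (x + 3) clamped to [0, 6] and divided by 6.
# Quick numerical check on a random input; tolerances absorb any rounding-order differences.
import torch
import torch.nn as nn

x = torch.randn(100)
assert torch.allclose(nn.Hardswish()(x), x * nn.Hardsigmoid()(x))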
class ELU(Module):
    r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
    in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.

    ELU is defined as:

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{inplace_str}'
class CELU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{inplace_str}'
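# Sketch: with alpha = 1.0 the CELU formula reduces to the ELU formula, since
# alpha * (exp(x / alpha) - 1) becomes exp(x) - 1. Quick numerical check, random input.
import torch
import torch.nn as nn

x = torch.randn(100)
assert torch.allclose(nn.CELU(alpha=1.0)(x), nn.ELU(alpha=1.0)(x))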
class SELU(Module):
    r"""Applies the SELU function element-wise, as:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    .. warning::
        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
        in order to get `Self-Normalizing Neural Networks`_.
        See :func:`torch.nn.init.calculate_gain` for more information.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class GLU(Module):
    r"""Applies the gated linear unit function
    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of
          additional dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.dim)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
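# Sketch: GLU splits the chosen dim in half and gates the first half with the sigmoid
# of the second, so that dimension shrinks from N to N/2. Shapes below are arbitrary.
import torch
import torch.nn as nn

x = torch.randn(4, 6)
m = nn.GLU(dim=1)                                 # splits dim 1 (size 6) into two halves of 3
a, b = x.chunk(2, dim=1)
assert torch.allclose(m(x), a * torch.sigmoid(b))
assert m(x).shape == (4, 3)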
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units function:

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for the Gaussian Distribution.

    When the approximate argument is ``'tanh'``, GELU is estimated with:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (str, optional): the gelu approximation algorithm to use:
            ``'none'`` | ``'tanh'``. Default: ``'none'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['approximate']
    approximate: str

    def __init__(self, approximate: str = 'none') -> None:
        super().__init__()
        self.approximate = approximate

    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input, approximate=self.approximate)

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'
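# Sketch: compares the exact GELU with the 'tanh' approximation. The approximation matches
# the tanh formula above to float precision and tracks the exact GELU closely but not
# exactly, so a loose tolerance is used for the second comparison. Illustrative values only.
import math
import torch
import torch.nn as nn

x = torch.randn(1000)
exact = nn.GELU()(x)
approx = nn.GELU(approximate='tanh')(x)
manual = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
assert torch.allclose(approx, manual, atol=1e-4)
assert torch.allclose(exact, approx, atol=1e-2)   # close, but not identical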
class Hardshrink(Module):
    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.

    Hardshrink is defined as:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return f'{self.lambd}'
class LeakyReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope (which is used for
            negative input values). Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional dimensions
        - Output: :math:`(*)`, same shape as the input

    .. image:: ../scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace', 'negative_slope']
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'negative_slope={self.negative_slope}{inplace_str}'
class LogSigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.logsigmoid(input)
class Softplus(Module):
    r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
    \log(1 + \exp(\beta * x))` element-wise.

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['beta', 'threshold']
    beta: int
    threshold: int

    def __init__(self, beta: int = 1, threshold: int = 20) -> None:
        super().__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return f'beta={self.beta}, threshold={self.threshold}'
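# Sketch: Softplus(x) = (1/beta) * log(1 + exp(beta * x)), and once beta * x exceeds
# `threshold` the implementation returns x itself for numerical stability.
# The beta/threshold values and inputs below are arbitrary example values.
import torch
import torch.nn as nn

m = nn.Softplus(beta=2, threshold=20)
x = torch.tensor([-1.0, 0.0, 1.0])
manual = (1 / 2) * torch.log(1 + torch.exp(2 * x))
assert torch.allclose(m(x), manual)

big = torch.tensor([15.0])                        # beta * x = 30 > threshold, so output is linear
assert torch.equal(m(big), big)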
class Softshrink(Module):
    r"""Applies the soft shrinkage function elementwise:

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return str(self.lambd)
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``nn.MultiheadAttention`` will use the optimized implementations of
    ``scaled_dot_product_attention()`` when possible.

    In addition to support for the new ``scaled_dot_product_attention()``
    function, for speeding up inference, MHA will use
    fastpath inference with support for Nested Tensors, iff:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
    - inputs are batched (3D) with ``batch_first==True``
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
      argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither
      ``key_padding_mask`` nor ``attn_mask`` is passed
    - autocast is disabled

    If the optimized inference fastpath implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
         https://arxiv.org/abs/2205.14135
    """

    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super().__setstate__(state)
    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True,
            is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)`
                when ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L`
                is the target sequence length, :math:`N` is the batch size, and :math:`E_q` is the query
                embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce
                the output. See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when
                ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is
                the source sequence length, :math:`N` is the batch size, and :math:`E_k` is the key
                embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is
                the source sequence length, :math:`N` is the batch size, and :math:`E_v` is the value
                embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within
                ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched
                `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary
                mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding
                ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
                and achieve the best performance for MHA. Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of
                shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the
                batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence
                length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a
                different mask for each entry in the batch. Binary and float masks are supported. For a
                binary mask, a ``True`` value indicates that the corresponding position is not allowed to
                attend. For a float mask, the mask values will be added to the attention weight.
                If both attn_mask and key_padding_mask are supplied, their types should match.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged
                across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this
                flag only has an effect when ``need_weights=True``. Default: ``True``
                (i.e. average weights across heads)
            is_causal: If specified, applies a causal mask as attention mask. Default: ``False``.
                Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing
                incorrect hints can result in incorrect execution, including forward and backward
                compatibility.

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is
              the embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If
              ``average_attn_weights=True``, returns attention weights averaged across heads of shape
              :math:`(L, S)` when input is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch
              size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length.
              If ``average_attn_weights=False``, returns attention weights per head of shape
              :math:`(\text{num\_heads}, L, S)` when input is unbatched or
              :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        why_not_fast_path = ''
        if ((attn_mask is not None and torch.is_floating_point(attn_mask))
                or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
            why_not_fast_path = "floating-point masks are not supported for fast path."

        is_batched = query.dim() == 3

        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif self.in_proj_weight is None:
            why_not_fast_path = "in_proj_weight was None"
        elif query.dtype != self.in_proj_weight.dtype:
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif (self.num_heads % 2) != 0:
            why_not_fast_path = "self.num_heads is not even"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
            why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
                                 is not supported with NestedTensor input"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif _is_make_fx_tracing():
                why_not_fast_path = "we are running make_fx tracing"
            elif not all(_check_arg_device(x) for x in tensor_args):
                why_not_fast_path = ("some Tensor argument's device is neither one of "
                                     f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
            elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
                why_not_fast_path = ("grad is enabled and at least one of query or the "
                                     "input/output projection weights or biases requires_grad")
            if not why_not_fast_path:
                merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)

                if self.in_proj_bias is not None and self.in_proj_weight is not None:
                    return torch._native_multi_head_attention(
                        query,
                        key,
                        value,
                        self.embed_dim,
                        self.num_heads,
                        self.in_proj_weight,
                        self.in_proj_bias,
                        self.out_proj.weight,
                        self.out_proj.bias,
                        merged_mask,
                        need_weights,
                        average_attn_weights,
                        mask_type)

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                                f"The fast path was not hit because {why_not_fast_path}")

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
    def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                    query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
        r"""Determine mask type and combine masks if necessary.

        If only one mask is provided, that mask and the corresponding mask type will be returned.
        If both masks are provided, they will both be expanded to shape
        ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``,
        and mask type 2 will be returned.

        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``

        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (0, 1, or 2)
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        if key_padding_mask is not None:
            mask_type = 1
            merged_mask = key_padding_mask

        if attn_mask is not None:
            # In this branch query can't be a nested tensor, so it has a shape
            batch_size, seq_len, _ = query.shape
            mask_type = 2

            # Always expands attn_mask to 4D
            if attn_mask.dim() == 3:
                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
            else:  # attn_mask.dim() == 2:
                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
            merged_mask = attn_mask_expanded

            if key_padding_mask is not None:
                key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
                merged_mask = attn_mask_expanded + key_padding_mask_expanded

        # no attn_mask and no key_padding_mask, returns None, None
        return merged_mask, mask_type
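# Standalone sketch (illustrative, not part of the class above): typical self-attention
# usage with batch_first=True, boolean masks passed to forward(), and merge_masks().
# All shapes, hyper-parameters, and mask contents below are arbitrary example values.
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
mha.eval()                                            # dropout off, deterministic output

x = torch.randn(2, 5, embed_dim)                      # (batch, seq, feature)

# Causal mask: for boolean masks, True means "do not attend".
causal_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
# Key padding mask: mark the last position of every sequence as padding.
key_padding_mask = torch.zeros(2, 5, dtype=torch.bool)
key_padding_mask[:, -1] = True

attn_output, attn_weights = mha(x, x, x, attn_mask=causal_mask,
                                key_padding_mask=key_padding_mask)
assert attn_output.shape == (2, 5, embed_dim)         # (N, L, E)
assert attn_weights.shape == (2, 5, 5)                # averaged over heads by default

# merge_masks combines both masks into a single additive (N, num_heads, L, S) mask.
float_attn_mask = torch.zeros(5, 5).masked_fill(causal_mask, float('-inf'))
float_kpm = torch.zeros(2, 5).masked_fill(key_padding_mask, float('-inf'))
merged, mask_type = mha.merge_masks(float_attn_mask, float_kpm, x)
assert merged.shape == (2, num_heads, 5, 5)
assert mask_type == 2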
class PReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.

    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    .. note::
        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
        no channel dim and the number of channels = 1.

    Args:
        num_parameters (int): number of :math:`a` to learn.
            Although it takes an int as input, only two values are legitimate:
            1, or the number of channels at input. Default: 1
        init (float): the initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Attributes:
        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).

    .. image:: ../scripts/activation_images/PReLU.png

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['num_parameters']
    num_parameters: int

    def __init__(self, num_parameters: int = 1, init: float = 0.25,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.num_parameters = num_parameters
        super().__init__()
        self.init = init
        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.constant_(self.weight, self.init)

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return f'num_parameters={self.num_parameters}'
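# Sketch: with num_parameters=3, PReLU learns one slope per channel (dim 1 of the input).
# All slopes start at `init`, so every negative input is initially scaled by 0.25.
# Shapes and values below are arbitrary example values.
import torch
import torch.nn as nn

m = nn.PReLU(num_parameters=3, init=0.25)
x = torch.full((2, 3, 4), -1.0)                   # (batch, channels, length)
assert m.weight.shape == (3,)
assert torch.allclose(m(x), torch.full((2, 3, 4), -0.25))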
class Softsign(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softsign.png

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.softsign(input)
class Tanhshrink(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Tanhshrink}(x) = x - \tanh(x)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanhshrink.png

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.tanhshrink(input)
class Softmin(Module):
    r"""Applies the Softmin function to an n-dimensional input Tensor, rescaling it so
    that the elements of the n-dimensional output Tensor lie in the range `[0, 1]` and sum to 1.

    Softmin is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which Softmin will be computed (so every slice
            along dim will sum to 1).

    Returns:
        a Tensor of the same dimension and shape as the input, with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmin(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmin(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return f'dim={self.dim}'
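# Sketch: Softmin is Softmax applied to the negated input, and each slice along `dim`
# still sums to 1. Random input; purely illustrative.
import torch
import torch.nn as nn

x = torch.randn(2, 3)
out = nn.Softmin(dim=1)(x)
assert torch.allclose(out, nn.Softmax(dim=1)(-x))
assert torch.allclose(out.sum(dim=1), torch.ones(2))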
class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor, rescaling it so
    that the elements of the n-dimensional output Tensor lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor then the unspecified
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Args:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
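# Sketch: the `dim` argument controls which slices are normalized; every slice along
# `dim` sums to 1. Illustrative only, with a random 2x3 input.
import torch
import torch.nn as nn

x = torch.randn(2, 3)
rows = nn.Softmax(dim=1)(x)
cols = nn.Softmax(dim=0)(x)
assert torch.allclose(rows.sum(dim=1), torch.ones(2))   # each row sums to 1
assert torch.allclose(cols.sum(dim=0), torch.ones(3))   # each column sums to 1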
class Softmax2d(Module):
    r"""Applies SoftMax over features to each spatial location.

    When given an image of ``Channels x Height x Width``, it will
    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmax2d()
        >>> # you softmax over the 2nd dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        if input.dim() not in (3, 4):
            raise ValueError(f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead")
        return F.softmax(input, -3, _stacklevel=5)
class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which LogSoftmax will be computed.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)

    Examples::

        >>> m = nn.LogSoftmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return f'dim={self.dim}'
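# Sketch: LogSoftmax(x) equals log(Softmax(x)) but is computed in a numerically stabler
# way, and it pairs with NLLLoss to reproduce CrossEntropyLoss on raw logits.
# Logits and targets below are arbitrary example values.
import torch
import torch.nn as nn

logits = torch.randn(4, 5)
target = torch.tensor([0, 2, 1, 4])

log_probs = nn.LogSoftmax(dim=1)(logits)
assert torch.allclose(log_probs, torch.log(nn.Softmax(dim=1)(logits)), atol=1e-6)

nll = nn.NLLLoss()(log_probs, target)
ce = nn.CrossEntropyLoss()(logits, target)
assert torch.allclose(nll, ce)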