# Source code for torch.nn.init
# mypy: allow-untyped-defs
"""This file contains utilities for initializing neural network parameters."""
import math
import warnings
from torch import Tensor
import torch
from typing import Optional as _Optional
# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b, generator=None):
    """Call ``tensor.uniform_(a, b)`` with autograd recording disabled.

    Args:
        tensor: the tensor to fill in place
        a: lower bound forwarded to ``uniform_``
        b: upper bound forwarded to ``uniform_``
        generator: optional generator forwarded to ``uniform_``

    Returns:
        Whatever ``tensor.uniform_`` returns.
    """
    with torch.no_grad():
        filled = tensor.uniform_(a, b, generator=generator)
    return filled
def _no_grad_normal_(tensor, mean, std, generator=None):
    """Call ``tensor.normal_(mean, std)`` with autograd recording disabled.

    Args:
        tensor: the tensor to fill in place
        mean: mean forwarded to ``normal_``
        std: standard deviation forwarded to ``normal_``
        generator: optional generator forwarded to ``normal_``

    Returns:
        Whatever ``tensor.normal_`` returns.
    """
    with torch.no_grad():
        filled = tensor.normal_(mean, std, generator=generator)
    return filled
def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    """Fill ``tensor`` in place with truncated-normal samples, without autograd.

    Samples are produced by drawing uniformly between the normal CDF values
    of the two cutoffs and mapping them back through the inverse CDF.
    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: the tensor to fill in place
        mean: mean of the underlying normal distribution
        std: standard deviation of the underlying normal distribution
        a: lower cutoff
        b: upper cutoff
        generator: optional generator forwarded to ``uniform_``

    Returns:
        The same ``tensor`` object.
    """
    def _standard_normal_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the lower and upper cutoffs.
        cdf_low = _standard_normal_cdf((a - mean) / std)
        cdf_high = _standard_normal_cdf((b - mean) / std)

        # Sample uniformly on [2*cdf_low - 1, 2*cdf_high - 1], the range that
        # erfinv maps back onto the truncated standard normal.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1, generator=generator)

        # Inverse-CDF transform, then rescale and shift to (mean, std).
        tensor.erfinv_().mul_(std * math.sqrt(2.)).add_(mean)

        # Clamp so the result stays inside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
def _no_grad_fill_(tensor, val):
    """Call ``tensor.fill_(val)`` with autograd recording disabled.

    Args:
        tensor: the tensor to fill in place
        val: the value forwarded to ``fill_``

    Returns:
        Whatever ``tensor.fill_`` returns.
    """
    with torch.no_grad():
        filled = tensor.fill_(val)
    return filled
def _no_grad_zero_(tensor):
    """Call ``tensor.zero_()`` with autograd recording disabled.

    Args:
        tensor: the tensor to zero in place

    Returns:
        Whatever ``tensor.zero_`` returns.
    """
    with torch.no_grad():
        zeroed = tensor.zero_()
    return zeroed
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; only read for
            ``'leaky_relu'``, where it is the negative slope (``0.01`` when
            omitted)

    Raises:
        ValueError: if ``nonlinearity`` is not one of the supported names, or
            if ``param`` for ``'leaky_relu'`` is neither ``None`` nor a
            non-boolean ``int`` or a ``float``

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    unit_gain_fns = (
        'linear',
        'conv1d',
        'conv2d',
        'conv3d',
        'conv_transpose1d',
        'conv_transpose2d',
        'conv_transpose3d',
        'sigmoid',
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    if nonlinearity != 'leaky_relu':
        raise ValueError(f"Unsupported nonlinearity {nonlinearity}")

    if param is None:
        negative_slope = 0.01
    elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
        # True/False are instances of int, so bool is excluded explicitly.
        negative_slope = param
    else:
        raise ValueError(f"negative_slope {param} not a valid number")
    return math.sqrt(2.0 / (1 + negative_slope ** 2))
def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

    :math:`\mathcal{U}(a, b)`.

    If ``tensor`` overrides ``__torch_function__``, the call is dispatched to
    that override instead of being executed here.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_uniform_(tensor, a, b, generator)
    return torch.overrides.handle_torch_function(
        uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
    )
def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    If ``tensor`` overrides ``__torch_function__``, the call is dispatched to
    that override instead of being executed here.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_normal_(tensor, mean, std, generator)
    return torch.overrides.handle_torch_function(
        normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
    )
def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.,
    std: float = 1.,
    a: float = -2.,
    b: float = 2.,
    generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
    return filled
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

    If ``tensor`` overrides ``__torch_function__``, the call is dispatched to
    that override instead of being executed here.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_fill_(tensor, val)
    return torch.overrides.handle_torch_function(
        constant_, (tensor,), tensor=tensor, val=val
    )
def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    filled = _no_grad_fill_(tensor, 1.)
    return filled
def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    zeroed = _no_grad_zero_(tensor)
    return zeroed
def eye_(tensor):
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Raises:
        ValueError: if ``tensor`` does not have exactly 2 dimensions

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    num_rows, num_cols = tensor.shape
    with torch.no_grad():
        # Write the identity straight into the caller's storage via ``out=``.
        torch.eye(num_rows, num_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups>1, each group of channels preserves identity.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Raises:
        ValueError: if ``tensor`` does not have 3, 4 or 5 dimensions, or if
            its first dimension is not divisible by ``groups``

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    ndim = tensor.ndimension()
    if ndim not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    shape = tensor.size()
    if shape[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    rows_per_group = shape[0] // groups
    diag_len = min(rows_per_group, shape[1])
    # Middle index of every trailing (temporal / spatial / volumetric) dim.
    center = tuple(extent // 2 for extent in shape[2:])

    with torch.no_grad():
        tensor.zero_()
        for group in range(groups):
            first_row = group * rows_per_group
            for offset in range(diag_len):
                tensor[(first_row + offset, offset) + center] = 1
    return tensor
def _calculate_fan_in_and_fan_out(tensor):
    """Compute ``(fan_in, fan_out)`` for a tensor with at least 2 dimensions.

    ``fan_in`` is ``size(1)`` and ``fan_out`` is ``size(0)``, each multiplied
    by the product of all remaining dimension sizes.

    Args:
        tensor: the tensor whose fans are computed

    Returns:
        A ``(fan_in, fan_out)`` tuple.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions
    """
    if tensor.dim() < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    # math.prod is not always available, accumulate the product manually
    # we could use functools.reduce but that is not supported by TorchScript
    receptive_field_size = 1
    for extent in tensor.shape[2:]:
        receptive_field_size *= extent

    return (
        tensor.size(1) * receptive_field_size,
        tensor.size(0) * receptive_field_size,
    )
def xavier_uniform_(
    tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_total = float(fan_in + fan_out)
    target_std = gain * math.sqrt(2.0 / fan_total)
    # A uniform distribution on (-bound, bound) has std bound / sqrt(3).
    bound = math.sqrt(3.0) * target_std
    return _no_grad_uniform_(tensor, -bound, bound, generator)
def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_total = float(fan_in + fan_out)
    target_std = gain * math.sqrt(2.0 / fan_total)
    return _no_grad_normal_(tensor, 0., target_std, generator)
def _calculate_correct_fan(tensor, mode):
    """Return the fan of ``tensor`` selected by ``mode``.

    Args:
        tensor: the tensor whose fan is computed
        mode: ``'fan_in'`` or ``'fan_out'`` (compared case-insensitively)

    Returns:
        The fan-in when ``mode`` is ``'fan_in'``, otherwise the fan-out.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        return fan_in
    return fan_out
def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    If ``tensor`` overrides ``__torch_function__``, the call is dispatched to
    that override. A tensor with a zero-sized dimension is returned unchanged
    after a warning.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        forwarded = {
            "tensor": tensor,
            "a": a,
            "mode": mode,
            "nonlinearity": nonlinearity,
            "generator": generator,
        }
        return torch.overrides.handle_torch_function(
            kaiming_uniform_, (tensor,), **forwarded
        )

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    target_std = gain / math.sqrt(fan)
    # A uniform distribution on (-bound, bound) has std bound / sqrt(3).
    bound = math.sqrt(3.0) * target_std
    with torch.no_grad():
        filled = tensor.uniform_(-bound, bound, generator=generator)
    return filled
def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    A tensor with a zero-sized dimension is returned unchanged after a warning.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    target_std = gain / math.sqrt(fan)
    with torch.no_grad():
        filled = tensor.normal_(0, target_std, generator=generator)
    return filled
def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # Nothing to fill.
        return tensor

    num_rows = tensor.size(0)
    num_cols = tensor.numel() // num_rows
    is_wide = num_rows < num_cols

    sample = tensor.new(num_rows, num_cols).normal_(0, 1, generator=generator)
    if is_wide:
        # Factorize the tall orientation; transposed back after the QR step.
        sample.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(sample)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    q *= torch.diag(r, 0).sign()

    if is_wide:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input `Tensor` as a sparse matrix.

    The non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (``std`` defaults to ``0.01``), as
    described in `Deep learning via Hessian-free optimization` -
    Martens, J. (2010).

    Both the normal samples and the per-column choice of rows to zero are
    drawn from ``generator`` when one is given, so a seeded generator
    reproduces the full result.

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
            (the per-column count is ``ceil(sparsity * rows)``)
        std: the standard deviation of the normal distribution used to generate
            the non-zero values
        generator: the torch Generator to sample from (default: None)

    Raises:
        ValueError: if ``tensor`` does not have exactly 2 dimensions

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
        for col_idx in range(cols):
            if generator is None:
                # No generator: keep drawing from the default RNG as before.
                row_indices = torch.randperm(rows)
            else:
                # randperm needs its output on the generator's device.
                row_indices = torch.randperm(
                    rows, generator=generator, device=generator.device
                )
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
# for backward compatibility
def _make_deprecate(meth):
    """Build a deprecated alias for an in-place initializer.

    The alias is named after ``meth`` with its last character dropped, emits
    a ``FutureWarning`` on every call and then forwards all arguments to
    ``meth``.

    Args:
        meth: the initializer to wrap; its ``__name__`` supplies both names

    Returns:
        The wrapping function, with ``__name__`` and ``__doc__`` set for the
        old name.
    """
    current_name = meth.__name__
    legacy_name = current_name[:-1]

    def deprecated_init(*args, **kwargs):
        message = f"`nn.init.{legacy_name}` is now deprecated in favor of `nn.init.{current_name}`."
        # stacklevel=2 attributes the warning to the caller of the alias.
        warnings.warn(message, FutureWarning, stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__name__ = legacy_name
    deprecated_init.__doc__ = fr"""
    {legacy_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{current_name}`.

    See :func:`~torch.nn.init.{current_name}` for details."""
    return deprecated_init
# Deprecated aliases without the trailing underscore. Each one is built by
# ``_make_deprecate``, so calling it emits a FutureWarning and then forwards
# to the in-place initializer of the same name plus ``_``.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)