import math
import warnings

from torch import Tensor
import torch


# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b):
    with torch.no_grad():
        return tensor.uniform_(a, b)


def _no_grad_normal_(tensor, mean, std):
    with torch.no_grad():
        return tensor.normal_(mean, std)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def _no_grad_fill_(tensor, val):
    with torch.no_grad():
        return tensor.fill_(val)


def _no_grad_zero_(tensor):
    with torch.no_grad():
        return tensor.zero_()

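# A minimal sketch (illustrative only, not part of the module) of the
# inverse-CDF construction used in _no_grad_trunc_normal_ above: sampling
# uniformly between the CDF values of the bounds and mapping back through
# erfinv confines all samples to [a, b]. The tensor name `w` is hypothetical.
#
#   >>> w = torch.empty(100000)
#   >>> _no_grad_trunc_normal_(w, mean=0., std=1., a=-2., b=2.)
#   >>> bool(w.min() >= -2.) and bool(w.max() <= 2.)
#   True
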
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky ReLU        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

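# Illustrative check (a sketch, not part of the module): the returned gains
# match the closed forms tabulated in the docstring above.
#
#   >>> calculate_gain('relu')             # sqrt(2) ~= 1.4142
#   >>> calculate_gain('tanh')             # 5/3 ~= 1.6667
#   >>> calculate_gain('leaky_relu', 0.2)  # sqrt(2 / (1 + 0.2**2)) ~= 1.3868
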
def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
    return _no_grad_uniform_(tensor, a, b)

def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
    return _no_grad_normal_(tensor, mean, std)

def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fills the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
    return _no_grad_fill_(tensor, val)

def ones_(tensor: Tensor) -> Tensor:
    r"""Fills the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    return _no_grad_fill_(tensor, 1.)

def zeros_(tensor: Tensor) -> Tensor:
    r"""Fills the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    return _no_grad_zero_(tensor)

def eye_(tensor):
    r"""Fills the 2-dimensional input `Tensor` with the identity
    matrix. Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
    return tensor

def dirac_(tensor, groups=1):
    r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac
    delta function. Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In the
    case of groups > 1, each group of channels preserves identity.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (optional): number of groups in the conv layer (default: 1)

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()

    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])

    with torch.no_grad():
        tensor.zero_()

        for g in range(groups):
            for d in range(min_dim):
                if dimensions == 3:  # Temporal convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
                elif dimensions == 4:  # Spatial convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2] = 1
                else:  # Volumetric convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2, tensor.size(4) // 2] = 1
    return tensor

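# A hedged sketch of the identity-preserving property documented above: after
# dirac_ initialization with matching channel counts, a convolution with
# "same" padding reproduces its input exactly. The names `w`, `x`, and `y`
# are hypothetical.
#
#   >>> w = torch.empty(8, 8, 3)
#   >>> nn.init.dirac_(w)
#   >>> x = torch.randn(1, 8, 10)
#   >>> y = torch.nn.functional.conv1d(x, w, padding=1)
#   >>> torch.allclose(y, x)  # each channel passes through its own filter
#   True
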
def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out

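# Illustrative sketch (not part of the module): for a Conv2d weight of shape
# (out_channels, in_channels, kH, kW), fan_in = in_channels * kH * kW and
# fan_out = out_channels * kH * kW. The weight `w` is hypothetical.
#
#   >>> w = torch.empty(64, 32, 3, 3)
#   >>> _calculate_fan_in_and_fan_out(w)
#   (288, 576)
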
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a)

def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

    return _no_grad_normal_(tensor, 0., std)

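# A hedged empirical check (not part of the module): Xavier initialization
# targets std = gain * sqrt(2 / (fan_in + fan_out)); for the hypothetical
# 400x600 weight below with gain=1 that is sqrt(2 / 1000) ~= 0.0447.
#
#   >>> w = torch.empty(400, 600)
#   >>> nn.init.xavier_normal_(w)
#   >>> w.std()  # expected to be close to 0.0447
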
def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out

def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)

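# Illustrative sketch (not part of the module): with mode='fan_out' and
# nonlinearity='relu', the target std is gain / sqrt(fan_out) =
# sqrt(2) / sqrt(256 * 3 * 3) ~= 0.0295 for the hypothetical weight below.
#
#   >>> w = torch.empty(256, 128, 3, 3)
#   >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
#   >>> w.std()  # expected to be close to 0.0295
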
def orthogonal_(tensor, gain=1):
    r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new(rows, cols).normal_(0, 1)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor

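# A minimal sketch (not part of the module) verifying the (semi-)orthogonality
# produced above: when rows <= cols, the rows of the result are orthonormal.
# The weight `w` is hypothetical.
#
#   >>> w = torch.empty(3, 5)
#   >>> nn.init.orthogonal_(w)
#   >>> torch.allclose(w @ w.t(), torch.eye(3), atol=1e-5)
#   True
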
def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor

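# Illustrative check (a sketch, not part of the module): sparse_ zeroes
# ceil(sparsity * rows) entries in every column, so each column of the
# hypothetical 10x4 weight below ends up with exactly 3 zeros (almost surely;
# a non-zero draw landing on exactly 0 aside).
#
#   >>> w = torch.empty(10, 4)
#   >>> nn.init.sparse_(w, sparsity=0.3)
#   >>> [int((w[:, c] == 0).sum()) for c in range(4)]
#   [3, 3, 3, 3]
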
# for backward compatibility
def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn("nn.init.{} is now deprecated in favor of nn.init.{}."
                      .format(old_name, new_name), stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)

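# Usage sketch (not part of the module): each deprecated alias forwards to its
# trailing-underscore variant and emits a UserWarning on use. `w` is
# hypothetical.
#
#   >>> w = torch.empty(3, 5)
#   >>> uniform(w)  # warns: nn.init.uniform is now deprecated in favor of nn.init.uniform_.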