Shortcuts

Source code for torch.nn.init

from __future__ import division

import math
import warnings

import torch

# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use with torch.no_grad(). The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b):
return tensor.uniform_(a, b)

def _no_grad_normal_(tensor, mean, std):
return tensor.normal_(mean, std)

return tensor.fill_(val)

return tensor.zero_()

def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    The transposed-convolution names (``'conv_transpose1d'``,
    ``'conv_transpose2d'``, ``'conv_transpose3d'``) are accepted as well and
    also map to a gain of 1.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function. It is only read
            for ``'leaky_relu'``, where it is the negative slope (0.01 when
            omitted).

    Raises:
        ValueError: if ``nonlinearity`` is not one of the supported names, or
            if ``param`` for ``'leaky_relu'`` is neither an ``int`` nor a
            ``float`` (booleans are rejected).

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    unit_gain_fns = ('linear', 'conv1d', 'conv2d', 'conv3d',
                     'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
                     'sigmoid')
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity != 'leaky_relu':
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    if param is None:
        slope = 0.01
    elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
        # bool subclasses int, so True/False have to be excluded explicitly.
        slope = param
    else:
        raise ValueError("negative_slope {} not a valid number".format(param))
    return math.sqrt(2.0 / (1 + slope ** 2))
def uniform_(tensor, a=0., b=1.):
    # type: (Tensor, float, float) -> Tensor
    r"""Fill the input Tensor in place with samples from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Returns:
        the value returned by ``_no_grad_uniform_`` for ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    return _no_grad_uniform_(tensor, a, b)
def normal_(tensor, mean=0., std=1.):
    # type: (Tensor, float, float) -> Tensor
    r"""Fill the input Tensor in place with samples from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Returns:
        the value returned by ``_no_grad_normal_`` for ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    return _no_grad_normal_(tensor, mean, std)
def constant_(tensor, val):
    # type: (Tensor, float) -> Tensor
    r"""Fill the input Tensor in place with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Returns:
        the value returned by ``_no_grad_fill_`` for ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    return _no_grad_fill_(tensor, val)
def ones_(tensor):
    # type: (Tensor) -> Tensor
    r"""Fill the input Tensor in place with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        the value returned by ``_no_grad_fill_`` for ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    return _no_grad_fill_(tensor, 1.)
def zeros_(tensor):
    # type: (Tensor) -> Tensor
    r"""Fill the input Tensor in place with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        the value returned by ``_no_grad_zero_`` for ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    return _no_grad_zero_(tensor)
def eye_(tensor):
    r"""Fill the 2-dimensional input Tensor in place with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as many
    inputs are preserved as possible. The tensor's ``requires_grad`` flag is
    passed through unchanged.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Returns:
        ``tensor``, filled in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    # torch.eye writes straight into ``tensor`` through ``out=``; no_grad
    # allows this even when ``tensor`` is a leaf that requires grad.
    with torch.no_grad():
        torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
def dirac_(tensor):
    r"""Fill the {3, 4, 5}-dimensional input Tensor in place with the Dirac
    delta function.

    Preserves the identity of the inputs in `Convolutional` layers, where as
    many input channels are preserved as possible. Every element is zeroed,
    then for each channel index ``d`` shared by dimension 0 (output) and
    dimension 1 (input) a single ``1`` is written at the centre of the kernel
    (index ``size // 2`` along every trailing dimension).

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`

    Returns:
        ``tensor``, filled in place.

    Raises:
        ValueError: if ``tensor`` does not have 3, 4 or 5 dimensions.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()
    # Only channels present on both the output and input axes can carry the
    # identity, so iterate over the smaller of the two.
    min_dim = min(sizes[0], sizes[1])
    # Centre index along each trailing (temporal / spatial / volumetric) dim.
    centre = tuple(s // 2 for s in sizes[2:])
    with torch.no_grad():
        tensor.zero_()
        for d in range(min_dim):
            tensor[(d, d) + centre] = 1
    return tensor
def _calculate_fan_in_and_fan_out(tensor): dimensions = tensor.dim() if dimensions < 2: raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") if dimensions == 2: # Linear fan_in = tensor.size(1) fan_out = tensor.size(0) else: num_input_fmaps = tensor.size(1) num_output_fmaps = tensor.size(0) receptive_field_size = 1 if tensor.dim() > 2: receptive_field_size = tensor.numel() fan_in = num_input_fmaps * receptive_field_size fan_out = num_output_fmaps * receptive_field_size return fan_in, fan_out
def xavier_uniform_(tensor, gain=1.):
    # type: (Tensor, float) -> Tensor
    r"""Fill the input Tensor in place following `Understanding the difficulty
    of training deep feedforward neural networks` - Glorot, X. & Bengio, Y.
    (2010), using a uniform distribution.

    Values are drawn from :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Returns:
        the value returned by ``_no_grad_uniform_`` for ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    target_std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    # U(-a, a) has variance a**2 / 3, so a = sqrt(3) * std hits the target std.
    limit = math.sqrt(3.0) * target_std
    return _no_grad_uniform_(tensor, -limit, limit)
def xavier_normal_(tensor, gain=1.):
    # type: (Tensor, float) -> Tensor
    r"""Fill the input Tensor in place following `Understanding the difficulty
    of training deep feedforward neural networks` - Glorot, X. & Bengio, Y.
    (2010), using a normal distribution.

    Values are drawn from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Returns:
        the value returned by ``_no_grad_normal_`` for ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_total = float(fan_in + fan_out)
    target_std = gain * math.sqrt(2.0 / fan_total)
    return _no_grad_normal_(tensor, 0., target_std)
def _calculate_correct_fan(tensor, mode): mode = mode.lower() valid_modes = ['fan_in', 'fan_out'] if mode not in valid_modes: raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fill the input Tensor in place following `Delving deep into rectifiers:
    Surpassing human-level performance on ImageNet classification` - He, K. et
    al. (2015), using a uniform distribution.

    Values are drawn from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    with ``gain = calculate_gain(nonlinearity, a)``. For the default
    ``'leaky_relu'`` this equals
    :math:`\sqrt{\frac{6}{(1 + a^2) \times \text{fan\_mode}}}`.

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (0 for
            ReLU by default)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing
            ``'fan_in'`` preserves the magnitude of the variance of the weights
            in the forward pass. Choosing ``'fan_out'`` preserves the
            magnitudes in the backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'``
            (default).

    Returns:
        the value returned by ``tensor.uniform_``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    chosen_fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    target_std = gain / math.sqrt(chosen_fan)
    # U(-b, b) has variance b**2 / 3, so b = sqrt(3) * std hits the target std.
    limit = math.sqrt(3.0) * target_std
    with torch.no_grad():
        return tensor.uniform_(-limit, limit)
def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fill the input Tensor in place following `Delving deep into rectifiers:
    Surpassing human-level performance on ImageNet classification` - He, K. et
    al. (2015), using a normal distribution.

    Values are drawn from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    with ``gain = calculate_gain(nonlinearity, a)``. For the default
    ``'leaky_relu'`` this equals
    :math:`\sqrt{\frac{2}{(1 + a^2) \times \text{fan\_mode}}}`.

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (0 for
            ReLU by default)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing
            ``'fan_in'`` preserves the magnitude of the variance of the weights
            in the forward pass. Choosing ``'fan_out'`` preserves the
            magnitudes in the backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'``
            (default).

    Returns:
        the value returned by ``tensor.normal_``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    chosen_fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    target_std = gain / math.sqrt(chosen_fan)
    with torch.no_grad():
        return tensor.normal_(0, target_std)
def orthogonal_(tensor, gain=1):
    r"""Fill the input Tensor in place with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013).

    The input tensor must have at least 2 dimensions, and for tensors with
    more than 2 dimensions the trailing dimensions are flattened. The result
    is multiplied by ``gain``.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Returns:
        ``tensor``, filled in place. A tensor with no elements is returned
        unchanged.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    # Nothing to fill. Returning early also avoids a ZeroDivisionError in
    # ``numel() // rows`` when the leading dimension is 0.
    if tensor.numel() == 0:
        return tensor

    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new_empty((rows, cols)).normal_(0, 1)

    # QR of a tall (or square) matrix yields orthonormal columns; work on the
    # transpose when the target is wide and transpose back afterwards.
    if rows < cols:
        flattened.t_()

    # Prefer torch.linalg.qr; fall back to the deprecated torch.qr on older
    # releases that lack it. Both return the reduced factorization.
    qr = getattr(getattr(torch, 'linalg', None), 'qr', None) or torch.qr
    q, r = qr(flattened)

    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
def sparse_(tensor, sparsity, std=0.01):
    r"""Fill the 2D input Tensor in place as a sparse matrix, as described in
    `Deep learning via Hessian-free optimization` - Martens, J. (2010).

    Every element is first drawn from :math:`\mathcal{N}(0, \text{std}^2)`;
    then, in each column, ``ceil(sparsity * rows)`` randomly chosen rows are
    set to zero.

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: the fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to
            generate the non-zero values

    Returns:
        ``tensor``, filled in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    zeros_per_col = int(math.ceil(sparsity * n_rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for j in range(n_cols):
            # A fresh permutation per column picks distinct rows to clear.
            dropped_rows = torch.randperm(n_rows)[:zeros_per_col]
            tensor[dropped_rows, j] = 0
    return tensor
# for backward compatibility
def _make_deprecate(meth):
    """Build a legacy alias for an in-place initializer.

    The alias is named after ``meth`` with its trailing underscore removed
    (e.g. ``uniform_`` -> ``uniform``). Each call emits a warning through
    ``warnings.warn`` (``stacklevel=2`` so it points at the caller) and then
    forwards all arguments to ``meth``, returning its result.
    """
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn("nn.init.{} is now deprecated in favor of nn.init.{}."
                      .format(old_name, new_name), stacklevel=2)
        return meth(*args, **kwargs)

    # Sphinx roles need the target wrapped in backticks to render as links.
    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources