from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from .. import functional as F
from .. import init
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module

__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d',
           'BatchNorm3d', 'LazyBatchNorm3d', 'SyncBatchNorm']


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )


class _LazyNormBase(LazyModuleMixin, _NormBase):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))  # type:ignore[union-attr]
                self.running_var.materialize((self.num_features,))  # type:ignore[union-attr]
            self.reset_parameters()

class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator,
    equivalent to ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
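
# Editor's sketch (not part of the original module): a hedged check of the docstring's
# claim that the forward pass normalizes with the biased (population) variance while the
# ``running_var`` buffer stores the unbiased estimate. ``momentum=None`` makes the buffer
# equal the first batch's unbiased variance exactly. The helper name is hypothetical.
def _example_biased_vs_unbiased_variance():
    import torch
    from torch import nn

    torch.manual_seed(0)
    x = torch.randn(20, 100)
    m = nn.BatchNorm1d(100, affine=False, momentum=None)
    out = m(x)  # training mode: normalizes with batch statistics

    # Normalization uses the biased variance ...
    manual = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + m.eps)
    assert torch.allclose(out, manual, atol=1e-5)

    # ... while running_var is updated with the unbiased variance.
    assert torch.allclose(m.running_var, x.var(dim=0, unbiased=True), atol=1e-5)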

class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
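
# Editor's sketch (not part of the original module): the lazy-initialization flow described
# above. ``num_features`` and all parameters/buffers are materialized from ``input.size(1)``
# on the first forward call, after which the module becomes a regular ``BatchNorm1d`` via
# ``cls_to_become``. The helper name is hypothetical.
def _example_lazy_batchnorm1d():
    import torch
    from torch import nn

    m = nn.LazyBatchNorm1d()   # num_features is not known yet
    x = torch.randn(4, 7)
    m(x)                       # first forward infers num_features = input.size(1) = 7

    assert isinstance(m, nn.BatchNorm1d)   # class swapped after materialization
    assert m.num_features == 7
    assert m.weight.shape == (7,) and m.running_mean.shape == (7,)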

class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
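
# Editor's sketch (not part of the original module): BatchNorm2d computes its statistics
# per channel over the ``(N, H, W)`` slices, so in training mode every output channel has
# (approximately) zero mean and unit biased variance. The helper name is hypothetical.
def _example_batchnorm2d_per_channel():
    import torch
    from torch import nn

    torch.manual_seed(0)
    m = nn.BatchNorm2d(5, affine=False)
    x = torch.randn(8, 5, 16, 16) * 3.0 + 2.0
    out = m(x)  # training mode: uses batch statistics

    # per-channel mean over (N, H, W) is ~0 and biased variance is ~1
    assert torch.allclose(out.mean(dim=(0, 2, 3)), torch.zeros(5), atol=1e-4)
    assert torch.allclose(out.var(dim=(0, 2, 3), unbiased=False), torch.ones(5), atol=1e-3)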

class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")

class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
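
# Editor's sketch (not part of the original module): the ``track_running_stats=False``
# behaviour described in the docstrings above. The running buffers are ``None`` and batch
# statistics are used in eval mode as well, so train() and eval() give the same output for
# the same batch. The helper name is hypothetical.
def _example_no_running_stats():
    import torch
    from torch import nn

    torch.manual_seed(0)
    m = nn.BatchNorm3d(4, affine=False, track_running_stats=False)
    assert m.running_mean is None and m.running_var is None

    x = torch.randn(2, 4, 3, 5, 5)
    out_train = m.train()(x)
    out_eval = m.eval()(x)   # still normalized with the batch statistics
    assert torch.allclose(out_train, out_eval)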

class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")

class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are sampled from
    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError("SyncBatchNorm number of input channels should be non-zero")

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (
            bn_training
            and self.training
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError(
                    "SyncBatchNorm expected input tensor to be on GPU or "
                    f"{torch._C._get_privateuse1_backend_name()}"
                )

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,
                world_size,
            )
    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
        :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>     torch.nn.Linear(20, 100),
            >>>     torch.nn.BatchNorm1d(100),
            >>> ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output
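
# Editor's sketch (not part of the original module): ``convert_sync_batchnorm`` recursively
# replaces every ``_BatchNorm`` child and carries its parameters and buffers over; the
# conversion itself needs no process group or GPU. The forward call below assumes
# ``torch.distributed`` has NOT been initialized in the running process, so ``need_sync``
# stays False and SyncBatchNorm falls back to ordinary batch norm, even on CPU. The helper
# name is hypothetical.
def _example_convert_sync_batchnorm():
    import torch
    from torch import nn

    torch.manual_seed(0)
    module = nn.Sequential(nn.Linear(20, 100), nn.BatchNorm1d(100))
    converted = nn.SyncBatchNorm.convert_sync_batchnorm(module)

    assert isinstance(converted[1], nn.SyncBatchNorm)
    assert converted[1].num_features == 100   # copied from the original BatchNorm1d

    # Single process, no initialized process group: falls back to F.batch_norm.
    out = converted(torch.randn(8, 20))
    assert out.shape == (8, 100)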