torchtnt.utils.distributed.revert_sync_batchnorm

torchtnt.utils.distributed.revert_sync_batchnorm(module: Module, device: Optional[Union[str, device]] = None) → Module

Helper function that converts all torch.nn.SyncBatchNorm layers in the given module to BatchNorm*D layers, reverting the effect of torch.nn.SyncBatchNorm.convert_sync_batchnorm().

Parameters:
  • module (nn.Module) – module containing one or more torch.nn.SyncBatchNorm layers to convert
  • device (optional) – device on which the BatchNorm*D layers should be created; defaults to CPU
Returns:

The original module with its torch.nn.SyncBatchNorm layers converted to BatchNorm*D layers. If the original module is itself a torch.nn.SyncBatchNorm layer, a new BatchNorm*D layer object is returned instead. Note that the returned BatchNorm*D layers do not retain input dimension information.

Example:

>>> # Network with an nn.BatchNorm layer
>>> module = torch.nn.Sequential(
...     torch.nn.Linear(20, 100),
...     torch.nn.BatchNorm1d(100),
... ).cuda()
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
>>> reverted_module = revert_sync_batchnorm(sync_bn_module, torch.device("cuda"))
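
A quick follow-up check (a sketch; the assertions below are not part of the TorchTNT API) is to walk the reverted module and confirm that no torch.nn.SyncBatchNorm layers remain, and that reverting a bare SyncBatchNorm layer yields a new object as described above:

>>> # Sketch: confirm the reversion removed every SyncBatchNorm layer
>>> assert not any(
...     isinstance(m, torch.nn.SyncBatchNorm) for m in reverted_module.modules()
... )
>>> # Standalone-layer case: reverting a bare SyncBatchNorm returns a new BatchNorm*D object
>>> bn = revert_sync_batchnorm(torch.nn.SyncBatchNorm(100))
>>> isinstance(bn, torch.nn.SyncBatchNorm)
False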
