torchtnt.utils.distributed.revert_sync_batchnorm
torchtnt.utils.distributed.revert_sync_batchnorm(module: Module, device: Optional[Union[str, device]] = None) → Module

Helper function to convert all torch.nn.SyncBatchNorm layers in the module to BatchNorm*D layers. This function reverts torch.nn.SyncBatchNorm.convert_sync_batchnorm().

Parameters:
- module (nn.Module) – module containing one or more torch.nn.SyncBatchNorm layers
- device (optional) – device on which the BatchNorm*D layers should be created; defaults to CPU

Returns:
The original module with the converted BatchNorm*D layers. If the original module is itself a torch.nn.SyncBatchNorm layer, a new BatchNorm*D layer object is returned instead. Note that the returned BatchNorm*D layers will not carry input dimension information.

Example:
>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>>     torch.nn.Linear(20, 100),
>>>     torch.nn.BatchNorm1d(100),
>>> ).cuda()
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
>>> reverted_module = revert_sync_batchnorm(sync_bn_module, torch.device("cuda"))
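A quick sanity check (an illustrative addition, not part of the original example): after reverting, no torch.nn.SyncBatchNorm layers should remain in the module tree. Because the returned layers carry no input dimension information, check for the absence of SyncBatchNorm rather than for a specific BatchNorm*D type:

>>> # Illustrative: confirm no SyncBatchNorm layers remain after the revert
>>> any(isinstance(m, torch.nn.SyncBatchNorm) for m in reverted_module.modules())
False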