importtorchfromtorch.nnimportConv1d,Conv2d,Conv3d,ReLU,Linear,BatchNorm1d,BatchNorm2d,BatchNorm3d# Used for identifying intrinsic modules used in quantizationclass_FusedModule(torch.nn.Sequential):pass
class ConvReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv1d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, conv, relu):
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``; AssertionError is kept for existing callers.
        if type(conv) is not Conv1d or type(relu) is not ReLU:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(conv), type(relu)))
        super().__init__(conv, relu)
class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv2d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, conv, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(conv) is not Conv2d or type(relu) is not ReLU:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(conv), type(relu)))
        super().__init__(conv, relu)
class ConvReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv3d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, conv, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(conv) is not Conv3d or type(relu) is not ReLU:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(conv), type(relu)))
        super().__init__(conv, relu)
class LinearReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose exact type must be ``Linear`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, linear, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(linear) is not Linear or type(relu) is not ReLU:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(linear), type(relu)))
        super().__init__(linear, relu)
class ConvBn1d(_FusedModule):
    r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv1d`` (subclasses are rejected).
        bn: module whose exact type must be ``BatchNorm1d`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, conv, bn):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(conv) is not Conv1d or type(bn) is not BatchNorm1d:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(conv), type(bn)))
        super().__init__(conv, bn)
class ConvBn2d(_FusedModule):
    r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv2d`` (subclasses are rejected).
        bn: module whose exact type must be ``BatchNorm2d`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, conv, bn):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(conv) is not Conv2d or type(bn) is not BatchNorm2d:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(conv), type(bn)))
        # Zero-argument super(), consistent with the sibling fused containers.
        super().__init__(conv, bn)
class ConvBnReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv1d`` (subclasses are rejected).
        bn: module whose exact type must be ``BatchNorm1d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if any module is not of the exact expected type.
    """

    def __init__(self, conv, bn, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if (type(conv) is not Conv1d or type(bn) is not BatchNorm1d
                or type(relu) is not ReLU):
            raise AssertionError(
                'Incorrect types for input modules {} {} {}'.format(
                    type(conv), type(bn), type(relu)))
        super().__init__(conv, bn, relu)
class ConvBnReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv2d`` (subclasses are rejected).
        bn: module whose exact type must be ``BatchNorm2d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if any module is not of the exact expected type.
    """

    def __init__(self, conv, bn, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if (type(conv) is not Conv2d or type(bn) is not BatchNorm2d
                or type(relu) is not ReLU):
            raise AssertionError(
                'Incorrect types for input modules {} {} {}'.format(
                    type(conv), type(bn), type(relu)))
        super().__init__(conv, bn, relu)
class ConvBn3d(_FusedModule):
    r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv3d`` (subclasses are rejected).
        bn: module whose exact type must be ``BatchNorm3d`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, conv, bn):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(conv) is not Conv3d or type(bn) is not BatchNorm3d:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(conv), type(bn)))
        super().__init__(conv, bn)
class ConvBnReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose exact type must be ``Conv3d`` (subclasses are rejected).
        bn: module whose exact type must be ``BatchNorm3d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if any module is not of the exact expected type.
    """

    def __init__(self, conv, bn, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if (type(conv) is not Conv3d or type(bn) is not BatchNorm3d
                or type(relu) is not ReLU):
            raise AssertionError(
                'Incorrect types for input modules {} {} {}'.format(
                    type(conv), type(bn), type(relu)))
        super().__init__(conv, bn, relu)
class BNReLU2d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        batch_norm: module whose exact type must be ``BatchNorm2d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, batch_norm, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(batch_norm) is not BatchNorm2d or type(relu) is not ReLU:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(batch_norm), type(relu)))
        super().__init__(batch_norm, relu)


class BNReLU3d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        batch_norm: module whose exact type must be ``BatchNorm3d`` (subclasses are rejected).
        relu: module whose exact type must be ``ReLU`` (subclasses are rejected).

    Raises:
        AssertionError: if either module is not of the exact expected type.
    """

    def __init__(self, batch_norm, relu):
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        if type(batch_norm) is not BatchNorm3d or type(relu) is not ReLU:
            raise AssertionError(
                'Incorrect types for input modules {} {}'.format(
                    type(batch_norm), type(relu)))
        super().__init__(batch_norm, relu)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.