import warnings
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor

from ..utils import _log_api_usage_once, _make_ntuple


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        # Drop the "num_batches_tracked" entry that a regular BatchNorm2d checkpoint
        # carries, since this frozen variant does not track batches.
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
    def forward(self, x: Tensor) -> Tensor:
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
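
# Example (not part of the module): a minimal sketch checking that FrozenBatchNorm2d
# reproduces a regular BatchNorm2d in eval mode once the statistics and affine
# parameters are copied over. Channel counts and shapes below are arbitrary.
import torch
from torchvision.ops import FrozenBatchNorm2d

bn = torch.nn.BatchNorm2d(8)
bn(torch.randn(32, 8, 4, 4))             # populate the running statistics
bn.eval()

frozen = FrozenBatchNorm2d(8)
frozen.load_state_dict(bn.state_dict())  # "num_batches_tracked" is dropped on load

x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    assert torch.allclose(bn(x), frozen(x), atol=1e-6)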
class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = _make_ntuple(kernel_size, _conv_dim)
                dilation = _make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )
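
# Example (not part of the module): the default-padding rule above. When ``padding``
# is None it becomes ``(kernel_size - 1) // 2 * dilation`` per spatial dimension,
# which preserves the spatial size for stride 1 and odd kernel sizes. The helper
# below is a hypothetical re-implementation of that rule, for illustration only.
from collections.abc import Sequence as _Seq


def _default_padding(kernel_size, dilation):
    if isinstance(kernel_size, int) and isinstance(dilation, int):
        return (kernel_size - 1) // 2 * dilation
    n = len(kernel_size) if isinstance(kernel_size, _Seq) else len(dilation)
    ks = kernel_size if isinstance(kernel_size, _Seq) else (kernel_size,) * n
    dil = dilation if isinstance(dilation, _Seq) else (dilation,) * n
    return tuple((k - 1) // 2 * d for k, d in zip(ks, dil))


print(_default_padding(3, 1))       # 1      -> 3x3 conv, stride 1: output size preserved
print(_default_padding((1, 7), 1))  # (0, 3) -> per-dimension padding for a 1x7 kernel
print(_default_padding(3, 2))       # 2      -> dilated 3x3 still keeps the spatial size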
class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None,
            in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
            convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
            on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None``
            this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place.
            Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included
            if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )
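
# Example (not part of the module): a short usage sketch. With the defaults this
# assembles Conv2d (bias=False) -> BatchNorm2d -> ReLU(inplace=True), and the
# default padding keeps the spatial size for stride 1. Shapes are arbitrary.
import torch
from torchvision.ops import Conv2dNormActivation

block = Conv2dNormActivation(in_channels=3, out_channels=16, kernel_size=3, stride=1)
x = torch.randn(1, 3, 32, 32)
print(block(x).shape)  # torch.Size([1, 16, 32, 32])

# With norm_layer=None the conv bias is re-enabled (``bias`` defaults to ``norm_layer is None``).
plain = Conv2dNormActivation(3, 16, norm_layer=None, activation_layer=None)
print(plain[0].bias is not None)  # True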
class Conv3dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution3d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input video.
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all sides of the input. Default: None,
            in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
            convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
            on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None``
            this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place.
            Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included
            if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv3d,
        )
class SqueezeExcitation(torch.nn.Module):
    """
    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
    Parameters ``activation`` and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.

    Args:
        input_channels (int): Number of channels in the input image
        squeeze_channels (int): Number of squeeze channels
        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
    """

    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        scale = self.avgpool(input)
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale)
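
# Example (not part of the module): the listing above stops at ``_scale``; the block's
# forward pass re-weights the input by that per-channel scale, i.e. roughly
# ``return self._scale(input) * input``. Channel counts below are arbitrary.
import torch
from torchvision.ops import SqueezeExcitation

se = SqueezeExcitation(input_channels=64, squeeze_channels=16)
x = torch.randn(2, 64, 14, 14)
print(se(x).shape)  # torch.Size([2, 64, 14, 14]) -- same shape, channels re-weighted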
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
            linear layer. If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
            on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None``
            this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation
            in-place. Default is ``None``, which uses the respective default values of the ``activation_layer``
            and Dropout layer.
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
        _log_api_usage_once(self)
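
# Example (not part of the module): a short usage sketch. The last entry of
# ``hidden_channels`` is the output dimension; the final Linear is followed by
# Dropout but no activation. Dimensions below are arbitrary.
import torch
from torchvision.ops import MLP

mlp = MLP(in_channels=32, hidden_channels=[64, 64, 10], dropout=0.1)
x = torch.randn(8, 32)
print(mlp(x).shape)  # torch.Size([8, 10])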
class Permute(torch.nn.Module):
    """This module returns a view of the tensor input with its dimensions permuted.

    Args:
        dims (List[int]): The desired ordering of dimensions
    """

    def __init__(self, dims: List[int]):
        super().__init__()
        self.dims = dims
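
# Example (not part of the module): the listing above omits ``forward``; a minimal
# sketch, as the docstring implies, would simply return the permuted view, e.g.
# ``return torch.permute(x, self.dims)``. Typical use: NCHW -> NHWC re-ordering.
import torch
from torchvision.ops import Permute

p = Permute([0, 2, 3, 1])
x = torch.randn(2, 3, 8, 8)
print(p(x).shape)  # torch.Size([2, 8, 8, 3])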