import re
from collections import OrderedDict
from functools import partial  # NOTE(review): unused in this view; presumably used by the DenseNet*_Weights definitions
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

from ..transforms._presets import ImageClassification  # NOTE(review): unused in this view; presumably used by the weights enums
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


# Public API of this module.
# NOTE(review): the DenseNet*_Weights enums listed here are not defined anywhere in the
# visible text of this file -- confirm they are defined (and not lost) in the real module.
__all__ = [
    "DenseNet",
    "DenseNet121_Weights",
    "DenseNet161_Weights",
    "DenseNet169_Weights",
    "DenseNet201_Weights",
    "densenet121",
    "densenet161",
    "densenet169",
    "densenet201",
]


class _DenseLayer(nn.Module):
    """One DenseNet-BC layer.

    A 1x1 bottleneck (norm1 -> relu1 -> conv1, producing ``bn_size * growth_rate``
    channels) is followed by a 3x3 convolution (norm2 -> relu2 -> conv2) that emits
    ``growth_rate`` new channels. The 3x3 conv uses stride 1 and padding 1, so the
    spatial size is preserved. Optional dropout is applied to the new channels.

    Args:
        num_input_features: channels of the concatenated input feature maps.
        growth_rate: number of channels this layer produces.
        bn_size: multiplier for the bottleneck width (``bn_size * growth_rate``).
        drop_rate: dropout probability on the output; disabled when ``<= 0``.
        memory_efficient: if True, the bottleneck is run under
            ``torch.utils.checkpoint`` when any input requires grad.
    """

    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        # Bottleneck: BN -> ReLU -> 1x1 conv.
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)

        # Feature producer: BN -> ReLU -> 3x3 conv (spatial size preserved).
        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        """Concatenate ``inputs`` along dim 1 (channels for Conv2d/BatchNorm2d) and apply the bottleneck."""
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        """Return True if at least one tensor in ``input`` has ``requires_grad`` set."""
        # Explicit loop instead of any(...) because of the TorchScript limitation noted above.
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        """Run ``bn_function`` under non-reentrant ``torch.utils.checkpoint``.

        The bottleneck activations are recomputed during backward instead of being
        kept, which trades speed for memory. The method is marked ``@torch.jit.unused``
        so TorchScript does not compile it. ``forward`` raises before reaching it when
        scripting.
        """

        # checkpoint passes the tensors positionally, so the closure re-collects them.
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input, use_reentrant=False)

    # TorchScript overload signatures: forward accepts a List[Tensor] ...
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
        pass

    # ... or a single Tensor. Only the final definition below has a real body.
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        """Compute ``growth_rate`` new feature channels from one or more previous feature maps.

        ``input`` may be a single Tensor or a list of Tensors; a single Tensor is
        wrapped in a list. All entries are concatenated along dim 1 before the
        bottleneck.
        """
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        # Checkpointing only pays off when gradients will flow back through the inputs.
        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            # F.dropout is a no-op when self.training is False.
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _DenseBlock(nn.ModuleDict):
    """A dense block: ``num_layers`` ``_DenseLayer`` modules named ``denselayer1`` ... ``denselayerN``.

    Layer ``i`` (0-based) sees ``num_input_features + i * growth_rate`` input channels,
    because it receives every preceding feature map.
    """

    # State-dict format version recorded by nn.Module. Presumably bumped when the
    # legacy 'norm.1'-style keys were renamed (see _load_state_dict) -- TODO confirm.
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            # Key names are part of the checkpoint format (denselayer1, denselayer2, ...).
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        """Apply each layer to the list of all previous feature maps and return the concatenation.

        The output has ``num_input_features + num_layers * growth_rate`` channels along dim 1.
        """
        features = [init_features]
        for name, layer in self.items():  # `name` is unused; layers run in insertion order
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    """Transition between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 average pool (stride 2).

    No ``forward`` is defined. ``nn.Sequential`` runs the submodules in the order they
    are assigned below. ``DenseNet`` uses this to halve the channel count, and the
    pooling halves the spatial size.
    """

    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    The network consists of:

    * a stem (``conv0`` 7x7/stride 2 on 3 input channels, ``norm0``, ``relu0``,
      ``pool0`` 3x3/stride 2 max-pool);
    * one ``_DenseBlock`` per entry of ``block_config``, with a ``_Transition``
      (halving channels and spatial size) between consecutive blocks;
    * a final BatchNorm (``norm5``), ReLU, global average pooling and a linear
      ``classifier``.

    Args:
        growth_rate (int): how many filters to add each layer (`k` in paper). Default: 32.
        block_config (tuple of 4 ints): how many layers in each pooling block.
            Default: ``(6, 12, 24, 16)``.
        num_init_features (int): the number of filters to learn in the first convolution layer.
            Default: 64.
        bn_size (int): multiplicative factor for number of bottle neck layers
            (i.e. ``bn_size * k`` features in the bottleneck layer). Default: 4.
        drop_rate (float): dropout rate after each dense layer. Default: 0.
        num_classes (int): number of classification classes. Default: 1000.
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:

        super().__init__()
        _log_api_usage_once(self)

        # First convolution
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", nn.BatchNorm2d(num_init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Each denseblock
        # num_features tracks the channel count flowing into the next module.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            # No transition after the last block; the final norm follows instead.
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        # Conv weights: Kaiming normal; BatchNorm: weight 1 / bias 0; Linear: bias 0
        # (the Linear weight keeps nn.Linear's default init). Iteration follows
        # self.modules() order, which also fixes the RNG consumption order.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Run the feature extractor, ReLU, global average pool, flatten, and classify.

        ``x`` must have 3 channels along dim 1 (``conv0`` is ``Conv2d(3, ...)``).
        Returns logits with ``num_classes`` entries along dim 1.
        """
        features = self.features(x)
        # norm5 is not followed by a ReLU inside self.features, so apply it here.
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
    """Fetch the state dict for ``weights`` and load it into ``model``, renaming legacy keys.

    Args:
        model: the DenseNet instance to populate.
        weights: weights entry whose ``get_state_dict`` provides the checkpoint.
            ``check_hash=True`` is passed; presumably this verifies the download hash.
        progress: forwarded to ``get_state_dict`` (download progress bar).
    """
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # Checkpoints returned by ``weights.get_state_dict`` may still use those names;
    # this pattern finds such keys. The rewrite below removes the extra dot, e.g.
    # '...denselayer1.norm.1.weight' -> '...denselayer1.norm1.weight'.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    # Iterate over a snapshot of the keys because the dict is mutated inside the loop.
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> DenseNet:
    """Shared builder behind the ``densenet*`` factory functions.

    Constructs a ``DenseNet`` with the given architecture (plus any extra ``kwargs``).
    If ``weights`` is not None, it also loads those weights.
    """
    if weights is not None:
        # Tie num_classes to the number of categories in the weights' metadata.
        # Conflict handling with a caller-supplied value lives in _ovewrite_named_param (._utils).
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

    if weights is not None:
        _load_state_dict(model=model, weights=weights, progress=progress)

    return model


# Metadata shared across DenseNet weight entries. Presumably merged into each
# DenseNet*_Weights entry's ``meta``; those definitions are not visible here -- confirm.
_COMMON_META = {
    "min_size": (29, 29),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/pull/116",
    "_docs": """These weights are ported from LuaTorch.""",
}
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Build a Densenet-121 model as described in
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet121_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.DenseNet121_Weights` below
            for details and the available values. When omitted, no pre-trained
            weights are used.
        progress (bool, optional): If True, a download progress bar is written to
            stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.densenet.DenseNet`` base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for details about this class.

    .. autoclass:: torchvision.models.DenseNet121_Weights
        :members:
    """
    # Architecture hyper-parameters for the 121-layer variant.
    growth_rate, block_config, num_init_features = 32, (6, 12, 24, 16), 64
    return _densenet(
        growth_rate, block_config, num_init_features, DenseNet121_Weights.verify(weights), progress, **kwargs
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Build a Densenet-161 model as described in
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet161_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.DenseNet161_Weights` below
            for details and the available values. When omitted, no pre-trained
            weights are used.
        progress (bool, optional): If True, a download progress bar is written to
            stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.densenet.DenseNet`` base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for details about this class.

    .. autoclass:: torchvision.models.DenseNet161_Weights
        :members:
    """
    # Architecture hyper-parameters for the 161-layer variant (wider: k=48, 96 stem filters).
    growth_rate, block_config, num_init_features = 48, (6, 12, 36, 24), 96
    return _densenet(
        growth_rate, block_config, num_init_features, DenseNet161_Weights.verify(weights), progress, **kwargs
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Build a Densenet-169 model as described in
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet169_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.DenseNet169_Weights` below
            for details and the available values. When omitted, no pre-trained
            weights are used.
        progress (bool, optional): If True, a download progress bar is written to
            stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.densenet.DenseNet`` base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for details about this class.

    .. autoclass:: torchvision.models.DenseNet169_Weights
        :members:
    """
    # Architecture hyper-parameters for the 169-layer variant.
    growth_rate, block_config, num_init_features = 32, (6, 12, 32, 32), 64
    return _densenet(
        growth_rate, block_config, num_init_features, DenseNet169_Weights.verify(weights), progress, **kwargs
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Build a Densenet-201 model as described in
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet201_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.DenseNet201_Weights` below
            for details and the available values. When omitted, no pre-trained
            weights are used.
        progress (bool, optional): If True, a download progress bar is written to
            stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.densenet.DenseNet`` base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for details about this class.

    .. autoclass:: torchvision.models.DenseNet201_Weights
        :members:
    """
    # Architecture hyper-parameters for the 201-layer variant.
    growth_rate, block_config, num_init_features = 32, (6, 12, 48, 32), 64
    return _densenet(
        growth_rate, block_config, num_init_features, DenseNet201_Weights.verify(weights), progress, **kwargs
    )
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.