
Models and pre-trained weights

The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.

Note

Backward compatibility is guaranteed for loading a serialized state_dict into a model that was created using an old version of PyTorch. On the contrary, loading entire saved models or serialized ScriptModules (serialized using older versions of PyTorch) may not preserve the historic behaviour. Refer to the PyTorch documentation on serialization semantics for details.
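
As a minimal sketch of the state_dict-based workflow that the note recommends (the file names are placeholders, not part of torchvision):

import torch
import torchvision.models as models

model = models.resnet18()

# Recommended: save only the state_dict; it can be reloaded across PyTorch versions.
torch.save(model.state_dict(), "resnet18_weights.pth")  # placeholder file name

# Later: recreate the architecture, then load the weights into it.
model = models.resnet18()
model.load_state_dict(torch.load("resnet18_weights.pth"))

# Not recommended for long-term storage: pickling the whole model object,
# which may not load correctly under a different PyTorch/torchvision version.
torch.save(model, "resnet18_full.pth")  # placeholder file name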

Classification

The models subpackage contains definitions for the following model architectures for image classification: AlexNet, VGG, ResNet, SqueezeNet, DenseNet, Inception v3, GoogLeNet, ShuffleNet v2, MobileNetV2, MobileNetV3, ResNeXt, Wide ResNet, MNASNet, EfficientNet, RegNet, VisionTransformer, and ConvNeXt.

You can construct a model with random weights by calling its constructor:

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
efficientnet_b0 = models.efficientnet_b0()
efficientnet_b1 = models.efficientnet_b1()
efficientnet_b2 = models.efficientnet_b2()
efficientnet_b3 = models.efficientnet_b3()
efficientnet_b4 = models.efficientnet_b4()
efficientnet_b5 = models.efficientnet_b5()
efficientnet_b6 = models.efficientnet_b6()
efficientnet_b7 = models.efficientnet_b7()
regnet_y_400mf = models.regnet_y_400mf()
regnet_y_800mf = models.regnet_y_800mf()
regnet_y_1_6gf = models.regnet_y_1_6gf()
regnet_y_3_2gf = models.regnet_y_3_2gf()
regnet_y_8gf = models.regnet_y_8gf()
regnet_y_16gf = models.regnet_y_16gf()
regnet_y_32gf = models.regnet_y_32gf()
regnet_y_128gf = models.regnet_y_128gf()
regnet_x_400mf = models.regnet_x_400mf()
regnet_x_800mf = models.regnet_x_800mf()
regnet_x_1_6gf = models.regnet_x_1_6gf()
regnet_x_3_2gf = models.regnet_x_3_2gf()
regnet_x_8gf = models.regnet_x_8gf()
regnet_x_16gf = models.regnet_x_16gf()
regnet_x_32gf = models.regnet_x_32gf()
vit_b_16 = models.vit_b_16()
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
convnext_tiny = models.convnext_tiny()
convnext_small = models.convnext_small()
convnext_base = models.convnext_base()
convnext_large = models.convnext_large()

We provide pre-trained models via the PyTorch torch.utils.model_zoo. These can be constructed by passing pretrained=True:

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
regnet_y_32gf = models.regnet_y_32gf(pretrained=True)
regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
regnet_x_16gf = models.regnet_x_16gf(pretrained=True)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
vit_b_16 = models.vit_b_16(pretrained=True)
vit_b_32 = models.vit_b_32(pretrained=True)
vit_l_16 = models.vit_l_16(pretrained=True)
vit_l_32 = models.vit_l_32(pretrained=True)
convnext_tiny = models.convnext_tiny(pretrained=True)
convnext_small = models.convnext_small(pretrained=True)
convnext_base = models.convnext_base(pretrained=True)
convnext_large = models.convnext_large(pretrained=True)

Instantiating a pre-trained model will download its weights to a cache directory. This directory can be set using the TORCH_HOME environment variable. See torch.hub.load_state_dict_from_url() for details.
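
A rough sketch of redirecting the cache, assuming the default torch.hub-based download path (the cache directory below is just an example):

import os
import torch
import torchvision.models as models

# Option 1: point TORCH_HOME somewhere else before the first download.
os.environ["TORCH_HOME"] = "/tmp/torch_cache"      # example path

# Option 2: set the hub directory programmatically.
torch.hub.set_dir("/tmp/torch_cache/hub")          # example path

resnet18 = models.resnet18(pretrained=True)        # weights are downloaded into the cache
print(torch.hub.get_dir())                         # shows where checkpoints are stored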

Some models use modules which have different training and evaluation behavior, such as batch normalization. To switch between these modes, use model.train() or model.eval() as appropriate. See train() or eval() for details.
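
A minimal sketch of switching modes around inference (AlexNet is used here only because it contains dropout layers):

import torch
import torchvision.models as models

model = models.alexnet()

model.eval()                         # dropout disabled, batch norm (if any) uses running statistics
with torch.no_grad():                # gradients are not needed for inference
    out = model(torch.rand(1, 3, 224, 224))
model.train()                        # switch back before training or fine-tuning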

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. You can use the following transform to normalize:

from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

An example of such normalization can be found in the ImageNet example in the pytorch/examples repository.
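
As a self-contained sketch of the full preprocessing and inference pipeline (the image path and the choice of ResNet-18 are arbitrary, not part of the original example):

import torch
from PIL import Image
from torchvision import models, transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

model = models.resnet18(pretrained=True)
model.eval()

img = Image.open("dog.jpg").convert("RGB")      # any RGB image; the path is a placeholder
batch = preprocess(img).unsqueeze(0)            # shape (1, 3, 224, 224)

with torch.no_grad():
    logits = model(batch)                       # shape (1, 1000)
probs = logits.softmax(dim=1)
top5 = probs.topk(5)
print(top5.indices, top5.values)                # ImageNet class indices and probabilities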

The process for obtaining the values of mean and std is roughly equivalent to:

import torch
from torchvision import datasets, transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)

means = []
stds = []
# `subset(...)` stands in for the (now unknown) subset of training images that
# was used; it is assumed to yield (3, H, W) image tensors.
for img in subset(dataset):
    means.append(torch.mean(img, dim=(1, 2)))  # per-channel mean
    stds.append(torch.std(img, dim=(1, 2)))    # per-channel standard deviation

mean = torch.mean(torch.stack(means), dim=0)
std = torch.mean(torch.stack(stds), dim=0)

Unfortunately, the concrete subset that was used is lost. For more information see this discussion or these experiments.

The sizes of the EfficientNet models depend on the variant. For the exact input sizes check here

ImageNet 1-crop accuracies

Model                               Acc@1     Acc@5
AlexNet                             56.522    79.066
VGG-11                              69.020    88.628
VGG-13                              69.928    89.246
VGG-16                              71.592    90.382
VGG-19                              72.376    90.876
VGG-11 with batch normalization     70.370    89.810
VGG-13 with batch normalization     71.586    90.374
VGG-16 with batch normalization     73.360    91.516
VGG-19 with batch normalization     74.218    91.842
ResNet-18                           69.758    89.078
ResNet-34                           73.314    91.420
ResNet-50                           76.130    92.862
ResNet-101                          77.374    93.546
ResNet-152                          78.312    94.046
SqueezeNet 1.0                      58.092    80.420
SqueezeNet 1.1                      58.178    80.624
Densenet-121                        74.434    91.972
Densenet-169                        75.600    92.806
Densenet-201                        76.896    93.370
Densenet-161                        77.138    93.560
Inception v3                        77.294    93.450
GoogleNet                           69.778    89.530
ShuffleNet V2 x1.0                  69.362    88.316
ShuffleNet V2 x0.5                  60.552    81.746
MobileNet V2                        71.878    90.286
MobileNet V3 Large                  74.042    91.340
MobileNet V3 Small                  67.668    87.402
ResNeXt-50-32x4d                    77.618    93.698
ResNeXt-101-32x8d                   79.312    94.526
Wide ResNet-50-2                    78.468    94.086
Wide ResNet-101-2                   78.848    94.284
MNASNet 1.0                         73.456    91.510
MNASNet 0.5                         67.734    87.490
EfficientNet-B0                     77.692    93.532
EfficientNet-B1                     78.642    94.186
EfficientNet-B2                     80.608    95.310
EfficientNet-B3                     82.008    96.054
EfficientNet-B4                     83.384    96.594
EfficientNet-B5                     83.444    96.628
EfficientNet-B6                     84.008    96.916
EfficientNet-B7                     84.122    96.908
regnet_x_400mf                      72.834    90.950
regnet_x_800mf                      75.212    92.348
regnet_x_1_6gf                      77.040    93.440
regnet_x_3_2gf                      78.364    93.992
regnet_x_8gf                        79.344    94.686
regnet_x_16gf                       80.058    94.944
regnet_x_32gf                       80.622    95.248
regnet_y_400mf                      74.046    91.716
regnet_y_800mf                      76.420    93.136
regnet_y_1_6gf                      77.950    93.966
regnet_y_3_2gf                      78.948    94.576
regnet_y_8gf                        80.032    95.048
regnet_y_16gf                       80.424    95.240
regnet_y_32gf                       80.878    95.340
vit_b_16                            81.072    95.318
vit_b_32                            75.912    92.466
vit_l_16                            79.662    94.638
vit_l_32                            76.972    93.070
convnext_tiny                       82.520    96.146
convnext_small                      83.616    96.650
convnext_base                       84.062    96.870
convnext_large                      84.414    96.976

Alexnet

alexnet([pretrained, progress])

AlexNet model architecture from the “One weird trick…” paper.

VGG

vgg11([pretrained, progress])

VGG 11-layer model (configuration “A”) from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

vgg11_bn([pretrained, progress])

VGG 11-layer model (configuration “A”) with batch normalization, from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

vgg13([pretrained, progress])

VGG 13-layer model (configuration “B”) from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

vgg13_bn([pretrained, progress])

VGG 13-layer model (configuration “B”) with batch normalization, from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

vgg16([pretrained, progress])

VGG 16-layer model (configuration “D”) from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

vgg16_bn([pretrained, progress])

VGG 16-layer model (configuration “D”) with batch normalization, from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

vgg19([pretrained, progress])

VGG 19-layer model (configuration “E”) from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

vgg19_bn([pretrained, progress])

VGG 19-layer model (configuration “E”) with batch normalization, from “Very Deep Convolutional Networks For Large-Scale Image Recognition”.

ResNet

resnet18([pretrained, progress])

ResNet-18 model from “Deep Residual Learning for Image Recognition”.

resnet34([pretrained, progress])

ResNet-34 model from “Deep Residual Learning for Image Recognition”.

resnet50([pretrained, progress])

ResNet-50 model from “Deep Residual Learning for Image Recognition”.

resnet101([pretrained, progress])

ResNet-101 model from “Deep Residual Learning for Image Recognition”.

resnet152([pretrained, progress])

ResNet-152 model from “Deep Residual Learning for Image Recognition”.

SqueezeNet

squeezenet1_0([pretrained, progress])

SqueezeNet model architecture from the “SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size” paper.

squeezenet1_1([pretrained, progress])

SqueezeNet 1.1 model from the official SqueezeNet repo.

DenseNet

densenet121([pretrained, progress])

Densenet-121 model from “Densely Connected Convolutional Networks”.

densenet169([pretrained, progress])

Densenet-169 model from “Densely Connected Convolutional Networks”.

densenet161([pretrained, progress])

Densenet-161 model from “Densely Connected Convolutional Networks”.

densenet201([pretrained, progress])

Densenet-201 model from “Densely Connected Convolutional Networks”.

Inception v3

inception_v3([pretrained, progress])

Inception v3 model architecture from “Rethinking the Inception Architecture for Computer Vision”.

GoogLeNet

googlenet([pretrained, progress])

GoogLeNet (Inception v1) model architecture from “Going Deeper with Convolutions”.

ShuffleNet v2

shufflenet_v2_x0_5([pretrained, progress])

Constructs a ShuffleNetV2 with 0.5x output channels, as described in “ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design”.

shufflenet_v2_x1_0([pretrained, progress])

Constructs a ShuffleNetV2 with 1.0x output channels, as described in “ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design”.

shufflenet_v2_x1_5([pretrained, progress])

Constructs a ShuffleNetV2 with 1.5x output channels, as described in “ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design”.

shufflenet_v2_x2_0([pretrained, progress])

Constructs a ShuffleNetV2 with 2.0x output channels, as described in “ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design”.

MobileNet v2

mobilenet_v2([pretrained, progress])

Constructs a MobileNetV2 architecture from “MobileNetV2: Inverted Residuals and Linear Bottlenecks”.

MobileNet v3

mobilenet_v3_large([pretrained, progress])

Constructs a large MobileNetV3 architecture from “Searching for MobileNetV3”.

mobilenet_v3_small([pretrained, progress])

Constructs a small MobileNetV3 architecture from “Searching for MobileNetV3”.

ResNext

resnext50_32x4d([pretrained, progress])

ResNeXt-50 32x4d model from “Aggregated Residual Transformations for Deep Neural Networks”.

resnext101_32x8d([pretrained, progress])

ResNeXt-101 32x8d model from “Aggregated Residual Transformations for Deep Neural Networks”.

Wide ResNet

wide_resnet50_2([pretrained, progress])

Wide ResNet-50-2 model from “Wide Residual Networks”.

wide_resnet101_2([pretrained, progress])

Wide ResNet-101-2 model from “Wide Residual Networks”.

MNASNet

mnasnet0_5([pretrained, progress])

MNASNet with depth multiplier of 0.5 from “MnasNet: Platform-Aware Neural Architecture Search for Mobile”.

mnasnet0_75([pretrained, progress])

MNASNet with depth multiplier of 0.75 from “MnasNet: Platform-Aware Neural Architecture Search for Mobile”.

mnasnet1_0([pretrained, progress])

MNASNet with depth multiplier of 1.0 from “MnasNet: Platform-Aware Neural Architecture Search for Mobile”.

mnasnet1_3([pretrained, progress])

MNASNet with depth multiplier of 1.3 from “MnasNet: Platform-Aware Neural Architecture Search for Mobile”.

EfficientNet

efficientnet_b0([pretrained, progress])

Constructs an EfficientNet B0 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

efficientnet_b1([pretrained, progress])

Constructs an EfficientNet B1 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

efficientnet_b2([pretrained, progress])

Constructs an EfficientNet B2 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

efficientnet_b3([pretrained, progress])

Constructs an EfficientNet B3 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

efficientnet_b4([pretrained, progress])

Constructs an EfficientNet B4 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

efficientnet_b5([pretrained, progress])

Constructs an EfficientNet B5 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

efficientnet_b6([pretrained, progress])

Constructs an EfficientNet B6 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

efficientnet_b7([pretrained, progress])

Constructs an EfficientNet B7 architecture from “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks”.

RegNet

regnet_y_400mf([pretrained, progress])

Constructs a RegNetY_400MF architecture from “Designing Network Design Spaces”.

regnet_y_800mf([pretrained, progress])

Constructs a RegNetY_800MF architecture from “Designing Network Design Spaces”.

regnet_y_1_6gf([pretrained, progress])

Constructs a RegNetY_1.6GF architecture from “Designing Network Design Spaces”.

regnet_y_3_2gf([pretrained, progress])

Constructs a RegNetY_3.2GF architecture from “Designing Network Design Spaces”.

regnet_y_8gf([pretrained, progress])

Constructs a RegNetY_8GF architecture from “Designing Network Design Spaces”.

regnet_y_16gf([pretrained, progress])

Constructs a RegNetY_16GF architecture from “Designing Network Design Spaces”.

regnet_y_32gf([pretrained, progress])

Constructs a RegNetY_32GF architecture from “Designing Network Design Spaces”.

regnet_y_128gf([pretrained, progress])

Constructs a RegNetY_128GF architecture from “Designing Network Design Spaces”.

regnet_x_400mf([pretrained, progress])

Constructs a RegNetX_400MF architecture from “Designing Network Design Spaces”.

regnet_x_800mf([pretrained, progress])

Constructs a RegNetX_800MF architecture from “Designing Network Design Spaces”.

regnet_x_1_6gf([pretrained, progress])

Constructs a RegNetX_1.6GF architecture from “Designing Network Design Spaces”.

regnet_x_3_2gf([pretrained, progress])

Constructs a RegNetX_3.2GF architecture from “Designing Network Design Spaces”.

regnet_x_8gf([pretrained, progress])

Constructs a RegNetX_8GF architecture from “Designing Network Design Spaces”.

regnet_x_16gf([pretrained, progress])

Constructs a RegNetX_16GF architecture from “Designing Network Design Spaces”.

regnet_x_32gf([pretrained, progress])

Constructs a RegNetX_32GF architecture from “Designing Network Design Spaces”.

VisionTransformer

vit_b_16([pretrained, progress])

Constructs a vit_b_16 architecture from “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”.

vit_b_32([pretrained, progress])

Constructs a vit_b_32 architecture from “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”.

vit_l_16([pretrained, progress])

Constructs a vit_l_16 architecture from “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”.

vit_l_32([pretrained, progress])

Constructs a vit_l_32 architecture from “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”.

ConvNeXt

convnext_tiny(*[, pretrained, progress])

ConvNeXt Tiny model architecture from the “A ConvNet for the 2020s” paper.

convnext_small(*[, pretrained, progress])

ConvNeXt Small model architecture from the “A ConvNet for the 2020s” paper.

convnext_base(*[, pretrained, progress])

ConvNeXt Base model architecture from the “A ConvNet for the 2020s” paper.

convnext_large(*[, pretrained, progress])

ConvNeXt Large model architecture from the “A ConvNet for the 2020s” paper.

Quantized Models

The following architectures provide support for INT8 quantized models. You can get a model with random weights by calling its constructor:

import torchvision.models as models
googlenet = models.quantization.googlenet()
inception_v3 = models.quantization.inception_v3()
mobilenet_v2 = models.quantization.mobilenet_v2()
mobilenet_v3_large = models.quantization.mobilenet_v3_large()
resnet18 = models.quantization.resnet18()
resnet50 = models.quantization.resnet50()
resnext101_32x8d = models.quantization.resnext101_32x8d()
shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5()
shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0()
shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5()
shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0()

Obtaining a pre-trained quantized model can be done with a few lines of code:

import torch
import torchvision.models as models

model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
model.eval()
# run the model with quantized inputs and weights
out = model(torch.rand(1, 3, 224, 224))

We provide pre-trained quantized weights for the following models:

Model                  Acc@1     Acc@5
MobileNet V2           71.658    90.150
MobileNet V3 Large     73.004    90.858
ShuffleNet V2 x1.0     68.360    87.582
ShuffleNet V2 x0.5     57.972    79.780
ResNet 18              69.494    88.882
ResNet 50              75.920    92.814
ResNext 101 32x8d      78.986    94.480
Inception V3           77.176    93.354
GoogleNet              69.826    89.404

Semantic Segmentation

The models subpackage contains definitions for the following model architectures for semantic segmentation: FCN (Fully Convolutional Networks), DeepLabV3, and LR-ASPP.

As with image classification models, all pre-trained models expect input images normalized in the same way. The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. They have been trained on images resized such that their minimum size is 520.

For details on how to plot the masks of such models, you may refer to Semantic segmentation models.

The pre-trained models have been trained on a subset of COCO train2017, on the 20 categories that are present in the Pascal VOC dataset. You can see more information on how the subset has been selected in references/segmentation/coco_utils.py. The classes that the pre-trained model outputs are the following, in order:

['__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
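
A minimal sketch of running one of these models and turning its output into a per-pixel class map (the input image and the choice of fcn_resnet50 are placeholders):

import torch
from PIL import Image
from torchvision import models, transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

model = models.segmentation.fcn_resnet50(pretrained=True)
model.eval()

img = Image.open("street.jpg").convert("RGB")   # placeholder image path
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    out = model(batch)["out"]                   # shape (1, 21, H, W): one score map per class
pred = out.argmax(dim=1)                        # (1, H, W) class indices, ordered as in the list above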

The accuracies of the pre-trained models evaluated on COCO val2017 are as follows

Network                        mean IoU    global pixelwise acc
FCN ResNet50                   60.5        91.4
FCN ResNet101                  63.7        91.9
DeepLabV3 ResNet50             66.4        92.4
DeepLabV3 ResNet101            67.4        92.4
DeepLabV3 MobileNetV3-Large    60.3        91.2
LR-ASPP MobileNetV3-Large      57.9        91.2

Fully Convolutional Networks

torchvision.models.segmentation.fcn_resnet50([…])

Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

torchvision.models.segmentation.fcn_resnet101([…])

Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

DeepLabV3

torchvision.models.segmentation.deeplabv3_resnet50([…])

Constructs a DeepLabV3 model with a ResNet-50 backbone.

torchvision.models.segmentation.deeplabv3_resnet101([…])

Constructs a DeepLabV3 model with a ResNet-101 backbone.

torchvision.models.segmentation.deeplabv3_mobilenet_v3_large([…])

Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

LR-ASPP

torchvision.models.segmentation.lraspp_mobilenet_v3_large([…])

Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

Object Detection, Instance Segmentation and Person Keypoint Detection

The models subpackage contains definitions for the following model architectures for detection: Faster R-CNN, FCOS, RetinaNet, SSD, SSDlite, Mask R-CNN, and Keypoint R-CNN.

The pre-trained models for detection, instance segmentation and keypoint detection are initialized with the classification models in torchvision.

The models expect a list of Tensor[C, H, W], in the range 0-1. The models internally resize the images but the behaviour varies depending on the model. Check the constructor of the models for more information. The output format of such models is illustrated in Instance segmentation models.

For object detection and instance segmentation, the pre-trained models return the predictions of the following classes:

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
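
For example, a rough sketch of running a pre-trained detector on a single image (the image path, the choice of fasterrcnn_resnet50_fpn, and the score threshold are arbitrary):

import torch
from PIL import Image
from torchvision import models, transforms

model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

img = transforms.ToTensor()(Image.open("street.jpg").convert("RGB"))  # values in [0, 1]

with torch.no_grad():
    predictions = model([img])          # the model takes a list of 3 x H x W tensors

pred = predictions[0]                   # one dict per input image
keep = pred["scores"] > 0.8             # arbitrary confidence threshold
boxes = pred["boxes"][keep]             # (N, 4) tensor of [x1, y1, x2, y2] coordinates
labels = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in pred["labels"][keep].tolist()]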

Here is a summary of the accuracies for the models trained on the instances set of COCO train2017 and evaluated on COCO val2017.

Network                                   box AP    mask AP    keypoint AP
Faster R-CNN ResNet-50 FPN                37.0      -          -
Faster R-CNN MobileNetV3-Large FPN        32.8      -          -
Faster R-CNN MobileNetV3-Large 320 FPN    22.8      -          -
FCOS ResNet-50 FPN                        39.2      -          -
RetinaNet ResNet-50 FPN                   36.4      -          -
SSD300 VGG16                              25.1      -          -
SSDlite320 MobileNetV3-Large              21.3      -          -
Mask R-CNN ResNet-50 FPN                  37.9      34.6       -

For person keypoint detection, the accuracies for the pre-trained models are as follows

Network                          box AP    mask AP    keypoint AP
Keypoint R-CNN ResNet-50 FPN     54.6      -          65.0

For person keypoint detection, the pre-trained models return the keypoints in the following order:

COCO_PERSON_KEYPOINT_NAMES = [
    'nose',
    'left_eye',
    'right_eye',
    'left_ear',
    'right_ear',
    'left_shoulder',
    'right_shoulder',
    'left_elbow',
    'right_elbow',
    'left_wrist',
    'right_wrist',
    'left_hip',
    'right_hip',
    'left_knee',
    'right_knee',
    'left_ankle',
    'right_ankle'
]
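
Analogously to the detection example above, a sketch for keypoint detection (the image path and score threshold are placeholders):

import torch
from PIL import Image
from torchvision import models, transforms

model = models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
model.eval()

img = transforms.ToTensor()(Image.open("people.jpg").convert("RGB"))

with torch.no_grad():
    pred = model([img])[0]

# pred["keypoints"] has shape (num_detections, 17, 3): (x, y, visibility) per keypoint,
# in the order given by COCO_PERSON_KEYPOINT_NAMES above.
keypoints = pred["keypoints"][pred["scores"] > 0.9]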

Runtime characteristics

The implementations of the models for object detection, instance segmentation and keypoint detection are efficient; the table below summarizes their training and inference speed and memory use.

The results in the following table were obtained using 8 GPUs. During training, a batch size of 2 per GPU is used for all models except SSD, which uses 4, and SSDlite, which uses 24. During testing, a batch size of 1 is used.

For test time, we report the time for the model evaluation and postprocessing (including mask pasting in image), but not the time for computing the precision-recall.

Network                                   train time (s / it)    test time (s / it)    memory (GB)
Faster R-CNN ResNet-50 FPN                0.2288                 0.0590                5.2
Faster R-CNN MobileNetV3-Large FPN        0.1020                 0.0415                1.0
Faster R-CNN MobileNetV3-Large 320 FPN    0.0978                 0.0376                0.6
FCOS ResNet-50 FPN                        0.1450                 0.0539                3.3
RetinaNet ResNet-50 FPN                   0.2514                 0.0939                4.1
SSD300 VGG16                              0.2093                 0.0744                1.5
SSDlite320 MobileNetV3-Large              0.1773                 0.0906                1.5
Mask R-CNN ResNet-50 FPN                  0.2728                 0.0903                5.4
Keypoint R-CNN ResNet-50 FPN              0.3789                 0.1242                6.8

Faster R-CNN

torchvision.models.detection.fasterrcnn_resnet50_fpn([…])

Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.

torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn([…])

Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.

torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn([…])

Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases.

FCOS

torchvision.models.detection.fcos_resnet50_fpn([…])

Constructs a FCOS model with a ResNet-50-FPN backbone.

RetinaNet

torchvision.models.detection.retinanet_resnet50_fpn([…])

Constructs a RetinaNet model with a ResNet-50-FPN backbone.

SSD

torchvision.models.detection.ssd300_vgg16([…])

Constructs an SSD model with input size 300x300 and a VGG16 backbone.

SSDlite

torchvision.models.detection.ssdlite320_mobilenet_v3_large([…])

Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described in “Searching for MobileNetV3” and “MobileNetV2: Inverted Residuals and Linear Bottlenecks”.

Mask R-CNN

torchvision.models.detection.maskrcnn_resnet50_fpn([…])

Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.

Keypoint R-CNN

torchvision.models.detection.keypointrcnn_resnet50_fpn([…])

Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

Video classification

We provide models for action recognition pre-trained on Kinetics-400. They have all been trained with the scripts provided in references/video_classification.

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB videos of shape (3 x T x H x W), where H and W are expected to be 112, and T is the number of video frames in a clip. The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.43216, 0.394666, 0.37645] and std = [0.22803, 0.22145, 0.216989].

Note

The normalization parameters are different from the image classification ones, and correspond to the mean and std from Kinetics-400.

Note

For now, normalization code can be found in references/video_classification/transforms.py, see the Normalize function there. Note that it differs from standard normalization for images because it assumes the video is 4d.
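
A rough sketch of normalizing a 4d video tensor by hand and classifying it (the random clip is just a stand-in for real, already-scaled video frames):

import torch
from torchvision import models

model = models.video.r3d_18(pretrained=True)
model.eval()

# A clip of 16 RGB frames at 112x112, with values already scaled to [0, 1].
clip = torch.rand(3, 16, 112, 112)

mean = torch.tensor([0.43216, 0.394666, 0.37645]).view(3, 1, 1, 1)
std = torch.tensor([0.22803, 0.22145, 0.216989]).view(3, 1, 1, 1)
clip = (clip - mean) / std                       # per-channel normalization of the 4d tensor

with torch.no_grad():
    logits = model(clip.unsqueeze(0))            # batch of shape (1, 3, 16, 112, 112)
print(logits.argmax(dim=1))                      # predicted Kinetics-400 class index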

Kinetics 1-crop accuracies for clip length 16 (16x112x112)

Network           Clip acc@1    Clip acc@5
ResNet 3D 18      52.75         75.45
ResNet MC 18      53.90         76.29
ResNet (2+1)D     57.50         78.81

ResNet 3D

torchvision.models.video.r3d_18([…])

Constructs an 18-layer ResNet3D model as in https://arxiv.org/abs/1711.11248

ResNet Mixed Convolution

torchvision.models.video.mc3_18([…])

Constructor for an 18-layer Mixed Convolution network as in https://arxiv.org/abs/1711.11248

ResNet (2+1)D

torchvision.models.video.r2plus1d_18([…])

Constructor for the 18 layer deep R(2+1)D network as in https://arxiv.org/abs/1711.11248
