Shortcuts

Source code for torchvision.models.alexnet

from functools import partial
from typing import Any, Optional

import torch
import torch.nn as nn

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]


class AlexNet(nn.Module):
    """AlexNet image classifier.

    The network has three parts:

    - a convolutional feature extractor: five ``Conv2d`` layers, each followed by
      ``ReLU``, plus three ``MaxPool2d`` stages;
    - an adaptive average pool that always produces a 6x6 map;
    - a fully connected head of three ``Linear`` layers with dropout.

    Args:
        num_classes: Number of logits produced by the final ``Linear`` layer.
        dropout: Drop probability used by both ``nn.Dropout`` layers in the head.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)

        # NOTE: keep the layer order exactly as-is. ``nn.Sequential`` names its
        # children by position, so these indices are the state-dict keys that
        # pretrained checkpoints are loaded against.
        feature_layers = [
            # conv1: 3 -> 64 channels, 11x11 kernel, stride 4
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # conv2: 64 -> 192 channels, 5x5 kernel
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # conv3..conv5: 3x3 kernels with padding 1 (spatial size preserved)
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Pool to a fixed 6x6 map so the first Linear layer's input width is a
        # constant (256 channels * 6 * 6).
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        flat_width = 256 * 6 * 6
        head_layers = [
            nn.Dropout(p=dropout),
            nn.Linear(flat_width, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Expects a batched ``(N, 3, H, W)`` tensor. The first conv takes 3
        channels, and flattening keeps dim 0 as the batch dimension. Returns
        a tensor of shape ``(N, num_classes)``.
        """
        pooled = self.avgpool(self.features(x))
        # Flatten everything after the batch dimension before the FC head.
        return self.classifier(torch.flatten(pooled, 1))


class AlexNet_Weights(WeightsEnum):
    """Pretrained weight sets available for the AlexNet builder.

    Each member is a ``Weights`` entry that bundles:

    - the checkpoint URL;
    - the inference transform factory;
    - a metadata dict, including the parameter count, ``min_size``, category
      names, the training-recipe link and reference ImageNet-1K metrics.
    """

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        # Preprocessing factory paired with these weights (224-pixel crop).
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 61100840,
            "min_size": (63, 63),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 56.522,
                    "acc@5": 79.066,
                }
            },
            "_docs": """
                These weights reproduce closely the results of the paper using a simplified training recipe.
            """,
        },
    )
    # Enum alias: DEFAULT refers to the same member as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.

    .. note::
        AlexNet was originally introduced in the `ImageNet Classification with
        Deep Convolutional Neural Networks
        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
        paper. Our implementation is based instead on the "One weird trick"
        paper above.

    Args:
        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.AlexNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.alexnet.AlexNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.AlexNet_Weights
        :members:
    """
    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        # The final Linear layer's width must match the checkpoint's category
        # count for load_state_dict below to succeed. The helper presumably
        # sets or validates kwargs["num_classes"]; confirm in ._utils.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
# The dictionary below is an internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs


# Legacy name -> checkpoint URL mapping, derived from the weights enum so the
# two cannot drift apart.
model_urls = _ModelURLs(
    {
        "alexnet": AlexNet_Weights.IMAGENET1K_V1.url,
    }
)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources