.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/transforms/plot_cutmix_mixup.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py:


===========================
How to use CutMix and MixUp
===========================

.. note::
    Try on `collab `_
    or :ref:`go to the end ` to download the full example code.

:class:`~torchvision.transforms.v2.CutMix` and
:class:`~torchvision.transforms.v2.MixUp` are popular augmentation strategies
that can improve classification accuracy.

These transforms are slightly different from the rest of the Torchvision
transforms, because they expect
**batches** of samples as input, not individual images. In this example we'll
explain how to use them: after the ``DataLoader``, or as part of a collation
function.

.. GENERATED FROM PYTHON SOURCE LINES 23-30

.. code-block:: Python

    import torch
    from torchvision.datasets import FakeData
    from torchvision.transforms import v2


    NUM_CLASSES = 100

.. GENERATED FROM PYTHON SOURCE LINES 31-35

Pre-processing pipeline
-----------------------

We'll use a simple but typical image classification pipeline:

.. GENERATED FROM PYTHON SOURCE LINES 35-49

.. code-block:: Python

    preproc = v2.Compose([
        v2.PILToTensor(),
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
    ])

    dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)

    img, label = dataset[0]
    print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    type(img) = , img.dtype = torch.float32, img.shape = torch.Size([3, 224, 224]), label = 67

..
.. GENERATED FROM PYTHON SOURCE LINES 50-54

One important thing to note is that neither CutMix nor MixUp is part of this
pre-processing pipeline. We'll add them a bit later once we define the
DataLoader. Just as a refresher, this is what the DataLoader and training loop
would look like if we weren't using CutMix or MixUp:

.. GENERATED FROM PYTHON SOURCE LINES 55-65

.. code-block:: Python

    from torch.utils.data import DataLoader

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    for images, labels in dataloader:
        print(f"{images.shape = }, {labels.shape = }")
        print(labels.dtype)
        # rest of the training loop goes here
        break

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
    torch.int64

.. GENERATED FROM PYTHON SOURCE LINES 68-77

Where to use MixUp and CutMix
-----------------------------

After the DataLoader
^^^^^^^^^^^^^^^^^^^^

Now let's add CutMix and MixUp. The simplest way is to do this right after the
DataLoader: the DataLoader has already batched the images and labels for us,
and this is exactly what these transforms expect as input:

.. GENERATED FROM PYTHON SOURCE LINES 77-91

.. code-block:: Python

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    cutmix = v2.CutMix(num_classes=NUM_CLASSES)
    mixup = v2.MixUp(num_classes=NUM_CLASSES)
    cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

    for images, labels in dataloader:
        print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
        images, labels = cutmix_or_mixup(images, labels)
        print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")

        # rest of the training loop goes here
        break

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Before CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
    After CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])

..
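
To build some intuition for what happens to a batch, here is a rough,
hand-written sketch of the MixUp idea. It is not the torchvision
implementation: the random stand-in batch, the ``Beta(1, 1)`` mixing
coefficient, and the choice to pair each sample with its neighbour in the
batch are all assumptions made for illustration.

.. code-block:: Python

    import torch
    import torch.nn.functional as F

    NUM_CLASSES = 100
    torch.manual_seed(0)

    # Stand-in batch with the same shapes as the DataLoader output above.
    images = torch.rand(4, 3, 224, 224)
    labels = torch.randint(0, NUM_CLASSES, size=(4,))

    # Mixing coefficient in [0, 1]; Beta(1, 1) is an assumption of this sketch.
    lam = float(torch.distributions.Beta(1.0, 1.0).sample())

    # Blend every sample with its neighbour in the batch, and blend the
    # one-hot encoded labels with the same weights.
    one_hot = F.one_hot(labels, num_classes=NUM_CLASSES).float()
    mixed_images = lam * images + (1 - lam) * images.roll(1, dims=0)
    mixed_labels = lam * one_hot + (1 - lam) * one_hot.roll(1, dims=0)

    print(f"{mixed_images.shape = }, {mixed_labels.shape = }")
    print(mixed_labels.sum(dim=1))  # each row is still a distribution over classes

Because each label is now a weighted blend of two one-hot vectors, it needs one
entry per class, which is why the label tensor gains a ``num_classes``
dimension.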
.. GENERATED FROM PYTHON SOURCE LINES 92-106

Note how the labels were also transformed: we went from a batched label of
shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
transformed labels can still be passed as-is to a loss function like
:func:`torch.nn.functional.cross_entropy`.

As part of the collation function
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Applying the transforms after the DataLoader is the simplest way to use CutMix
and MixUp, but one disadvantage is that it does not take advantage of the
DataLoader's multi-processing. To get that, we can apply those transforms as
part of the collation function (refer to the `PyTorch docs `_ to learn more
about collation).

.. GENERATED FROM PYTHON SOURCE LINES 107-123

.. code-block:: Python

    from torch.utils.data import default_collate


    def collate_fn(batch):
        return cutmix_or_mixup(*default_collate(batch))


    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)

    for images, labels in dataloader:
        print(f"{images.shape = }, {labels.shape = }")
        # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
        # rest of the training loop goes here
        break

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])

.. GENERATED FROM PYTHON SOURCE LINES 124-135

Non-standard input format
-------------------------

So far we've used a typical sample structure where we pass ``(images,
labels)`` as inputs. MixUp and CutMix will magically work by default with most
common sample structures: tuples where the second element is a tensor label,
or dicts with a "label[s]" key. Look at the documentation of the
``labels_getter`` parameter for more details.

If your samples have a different structure, you can still use CutMix and MixUp
by passing a callable to the ``labels_getter`` parameter. For example:

.. GENERATED FROM PYTHON SOURCE LINES 135-151

..
.. code-block:: Python

    batch = {
        "imgs": torch.rand(4, 3, 224, 224),
        "target": {
            "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
            "some_other_key": "this is going to be passed-through"
        }
    }


    def labels_getter(batch):
        return batch["target"]["classes"]


    out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
    print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    out['imgs'].shape = torch.Size([4, 3, 224, 224]), out['target']['classes'].shape = torch.Size([4, 100])

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.169 seconds)


.. _sphx_glr_download_auto_examples_transforms_plot_cutmix_mixup.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_cutmix_mixup.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_cutmix_mixup.py `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_