
Illustration of transforms

This example illustrates the various transforms available in the torchvision.transforms module.

# sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png"

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)


def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

Pad

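The Pad transform (see also pad()) pads every side of an image by the given number of pixels. With the default padding_mode='constant', the new border pixels are set to the fill value, which defaults to 0 (black).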

padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)
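To see what constant padding does to the pixel array, here is a minimal NumPy sketch. It is an illustration only, not torchvision's implementation, and the helper name pad_constant is hypothetical.

```python
import numpy as np

def pad_constant(img: np.ndarray, padding: int, fill: int = 0) -> np.ndarray:
    """Sketch: pad an (H, W, C) array by `padding` pixels on every side with `fill`."""
    # Pad height and width, but leave the channel axis untouched
    pad_width = ((padding, padding), (padding, padding), (0, 0))
    return np.pad(img, pad_width, mode="constant", constant_values=fill)

img = np.full((4, 6, 3), 200, dtype=np.uint8)
padded = pad_constant(img, padding=3)
print(padded.shape)  # (10, 12, 3)
```

Each dimension grows by twice the padding. np.pad also accepts mode="edge", "reflect" and "symmetric", which behave analogously to Pad's other padding_mode values.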

Resize

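The Resize transform (see also resize()) resizes an image. An integer size matches the smaller edge of the image to that value while keeping the aspect ratio; a (height, width) pair sets the output size exactly.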

# PIL's Image.size is (width, height), while Resize expects (height, width)
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size[::-1])]
plot(resized_imgs)
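The integer-size rule can be written out as plain arithmetic. This is a sketch of the documented behavior (smaller edge set to size, aspect ratio kept), not torchvision code; the function name resized_hw is hypothetical, and truncating the longer edge to an int is an assumption about rounding.

```python
def resized_hw(height, width, size):
    """Sketch (not torchvision code): output (height, width) when the
    smaller edge is resized to `size` and the aspect ratio is kept."""
    if width <= height:
        short, long = width, height
    else:
        short, long = height, width
    new_short = size
    # Assumption: the longer edge is truncated to an int
    new_long = int(size * long / short)
    # Put the two edges back in (height, width) order
    if width <= height:
        return new_long, new_short
    return new_short, new_long

print(resized_hw(480, 640, 240))  # (240, 320)
```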

CenterCrop

The CenterCrop transform (see also center_crop()) crops the given image at the center.

# PIL's Image.size is (width, height), while CenterCrop expects (height, width)
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size[::-1])]
plot(center_crops)
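A center crop is just array slicing around the middle of the image. The NumPy sketch below assumes the crop fits inside the image; the helper name center_crop is hypothetical and the rounding of odd margins is this sketch's own choice.

```python
import numpy as np

def center_crop(img, out_h, out_w):
    """Sketch: crop an (H, W, ...) array to (out_h, out_w) around its center."""
    h, w = img.shape[:2]
    # When the leftover margin is odd, Python's round() picks the offset
    top = int(round((h - out_h) / 2.0))
    left = int(round((w - out_w) / 2.0))
    return img[top:top + out_h, left:left + out_w]

img = np.arange(5 * 5).reshape(5, 5)
print(center_crop(img, 3, 3))
```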

FiveCrop

The FiveCrop transform (see also five_crop()) crops the given image into its four corners and its center, returning a tuple of five images.
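Because FiveCrop returns a tuple rather than a single image, its output has to be unpacked before plotting. The NumPy sketch below shows which regions the five crops cover; the 100x100 crop size and the helper name five_crop are illustrative assumptions, not values taken from this example.

```python
import numpy as np

def five_crop(img, crop_h, crop_w):
    """Sketch: return (top_left, top_right, bottom_left, bottom_right, center)."""
    h, w = img.shape[:2]
    if crop_h > h or crop_w > w:
        raise ValueError("crop size must not exceed the image size")
    top_left = img[:crop_h, :crop_w]
    top_right = img[:crop_h, w - crop_w:]
    bottom_left = img[h - crop_h:, :crop_w]
    bottom_right = img[h - crop_h:, w - crop_w:]
    # Center crop, computed like a plain center crop
    top = int(round((h - crop_h) / 2.0))
    left = int(round((w - crop_w) / 2.0))
    center = img[top:top + crop_h, left:left + crop_w]
    return top_left, top_right, bottom_left, bottom_right, center

img = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for the astronaut image
crops = five_crop(img, 100, 100)
print([c.shape for c in crops])
```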


Grayscale

The Grayscale transform (see also to_grayscale()) converts an image to grayscale. By default it returns a single-channel image, which is why the plot below uses cmap='gray'.

gray_img = T.Grayscale()(orig_img)
plot([gray_img], cmap='gray')
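For PIL images, grayscale conversion boils down to a weighted sum of the color channels. The sketch below uses the ITU-R 601-2 luma weights that PIL documents for Image.convert("L"); PIL's own integer arithmetic may round slightly differently, and the helper name to_grayscale_np is hypothetical.

```python
import numpy as np

def to_grayscale_np(rgb):
    """Sketch: convert an (H, W, 3) uint8 RGB array to an (H, W) uint8 luma array."""
    # ITU-R 601-2 luma weights for R, G, B
    weights = np.array([0.299, 0.587, 0.114])
    luma = rgb.astype(np.float64) @ weights
    return np.clip(np.round(luma), 0, 255).astype(np.uint8)

pixels = np.array([[[255, 255, 255], [0, 0, 0], [255, 0, 0]]], dtype=np.uint8)
print(to_grayscale_np(pixels))
```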

Random transforms

The following transforms are random, which means that the same transformer instance will produce a different result each time it transforms a given image.

ColorJitter

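The ColorJitter transform randomly changes the brightness, contrast, saturation, and hue of an image. On every call, a new random factor is drawn for each property that is set; the example below jitters only brightness and hue.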

jitter = T.ColorJitter(brightness=.5, hue=.3)
jittered_imgs = [jitter(orig_img) for _ in range(4)]
plot(jittered_imgs)
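A scalar brightness value b means the brightness factor is drawn uniformly from [max(0, 1 - b), 1 + b], and a scalar hue value h means the hue shift is drawn from [-h, h]. The sketch below shows that sampling plus a simple brightness adjustment (scaling toward or away from black); it does not implement the hue rotation itself, and the helper names are hypothetical.

```python
import random
import numpy as np

def jitter_params(brightness=0.5, hue=0.3, rng=random):
    """Sketch: draw one brightness factor and one hue shift."""
    b_factor = rng.uniform(max(0.0, 1.0 - brightness), 1.0 + brightness)
    # Hue shift as a fraction of the hue circle
    hue_shift = rng.uniform(-hue, hue)
    return b_factor, hue_shift

def adjust_brightness_np(img, factor):
    """Sketch: scale uint8 pixel values by `factor` and clip to 0..255."""
    out = img.astype(np.float64) * factor
    return np.clip(np.round(out), 0, 255).astype(np.uint8)

factor, shift = jitter_params(rng=random.Random(0))
print(factor, shift)
```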

GaussianBlur

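The GaussianBlur transform (see also gaussian_blur()) blurs an image with a Gaussian kernel. kernel_size gives the size of the kernel (each value must be a positive odd integer), and when sigma is a (min, max) tuple, a new standard deviation is drawn uniformly from that range on every call.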

blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)
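A Gaussian blur convolves the image with normalized Gaussian weights. The sketch below builds such a kernel in NumPy for a (kx, ky) kernel size and one sampled sigma; it shows the technique rather than reproducing torchvision's code, and the helper names are hypothetical.

```python
import random
import numpy as np

def gaussian_kernel1d(kernel_size, sigma):
    """Sketch: normalized 1-D Gaussian weights for an odd kernel size."""
    if kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError("kernel_size must be a positive odd integer")
    half = (kernel_size - 1) / 2.0
    x = np.linspace(-half, half, kernel_size)
    pdf = np.exp(-0.5 * (x / sigma) ** 2)
    return pdf / pdf.sum()

def gaussian_kernel2d(kernel_size, sigma):
    """Sketch: separable 2-D kernel from (kx, ky) sizes; result has shape (ky, kx)."""
    kx = gaussian_kernel1d(kernel_size[0], sigma)
    ky = gaussian_kernel1d(kernel_size[1], sigma)
    # Outer product: rows follow y, columns follow x
    return np.outer(ky, kx)

sigma = random.uniform(0.1, 5.0)  # like sigma=(0.1, 5): a fresh value per call
kernel = gaussian_kernel2d((5, 9), sigma)
print(kernel.shape, kernel.sum())
```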

RandomPerspective

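The RandomPerspective transform (see also perspective()) applies a random perspective transformation to an image. distortion_scale (between 0 and 1) controls how strong the distortion is, and p is the probability that the transform is applied; p=1.0 makes every call distort the image.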

perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
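A perspective warp is described by a homography with eight coefficients, which can be solved from four point correspondences (for example, the image corners before and after they are displaced). The NumPy sketch below shows that standard linear solve; the corner displacement is made up, the helper names are hypothetical, and this is not claimed to be torchvision's exact code.

```python
import numpy as np

def perspective_coeffs(src, dst):
    """Sketch: solve for (a, b, c, d, e, f, g, h) so that every dst point (x, y)
    maps back to its src point (u, v) via
        u = (a*x + b*y + c) / (g*x + h*y + 1)
        v = (d*x + e*y + f) / (g*x + h*y + 1)
    Needs four point pairs in general position (no three collinear)."""
    rows, rhs = [], []
    for (u, v), (x, y) in zip(src, dst):
        rows.append([x, y, 1, 0, 0, 0, -u * x, -u * y])
        rows.append([0, 0, 0, x, y, 1, -v * x, -v * y])
        rhs.extend([u, v])
    return np.linalg.solve(np.array(rows, dtype=float), np.array(rhs, dtype=float))

def apply_coeffs(coeffs, x, y):
    """Map a point (x, y) through the homography."""
    a, b, c, d, e, f, g, h = coeffs
    den = g * x + h * y + 1
    return (a * x + b * y + c) / den, (d * x + e * y + f) / den

# Hypothetical corner displacement for a 100x100 image
corners = [(0, 0), (99, 0), (99, 99), (0, 99)]
moved = [(10, 5), (90, 0), (99, 99), (0, 80)]
coeffs = perspective_coeffs(corners, moved)
print(apply_coeffs(coeffs, 10, 5))  # close to the original corner (0, 0)
```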

RandomRotation

The RandomRotation transform (see also rotate()) rotates an image by a random angle, drawn on every call from the degrees range.

rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
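By default, RandomRotation keeps the output the same size as the input, so the corners of the rotated image get cut off; with expand=True, the output is enlarged to hold the entire rotated image. The sketch below computes the bounding box of a rotated rectangle, which is the geometry behind that enlargement. torchvision's exact rounding is not reproduced, and the helper name is hypothetical.

```python
import math
import random

def rotated_canvas(width, height, angle_deg):
    """Sketch: (w, h) of the box that holds a width x height image rotated by angle_deg."""
    a = math.radians(angle_deg)
    c, s = abs(math.cos(a)), abs(math.sin(a))
    return width * c + height * s, width * s + height * c

angle = random.uniform(0, 180)  # degrees=(0, 180): angle drawn uniformly per call
print(angle, rotated_canvas(512, 512, angle))
```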

RandomAffine

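The RandomAffine transform (see also affine()) applies a random affine transformation to an image. Here the rotation angle is drawn from 30 to 70 degrees, the translation is up to 10% of the width horizontally and up to 30% of the height vertically, and the scale factor is drawn from 0.5 to 0.75.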

affine_transformer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transformer(orig_img) for _ in range(4)]
plot(affine_imgs)
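The parameters above are sampled independently on every call. The sketch below mirrors that sampling for the values used in this example; rounding the translation to whole pixels is an assumption, and the helper name affine_params is hypothetical.

```python
import random

def affine_params(width, height, degrees=(30, 70), translate=(0.1, 0.3),
                  scale=(0.5, 0.75), rng=random):
    """Sketch: draw (angle, (tx, ty), scale) like the RandomAffine example."""
    angle = rng.uniform(*degrees)
    # Translation limits are fractions of the image width and height
    max_dx = translate[0] * width
    max_dy = translate[1] * height
    # Assumption: shifts are rounded to whole pixels
    tx = int(round(rng.uniform(-max_dx, max_dx)))
    ty = int(round(rng.uniform(-max_dy, max_dy)))
    s = rng.uniform(*scale)
    return angle, (tx, ty), s

print(affine_params(512, 512, rng=random.Random(0)))
```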

RandomCrop

The RandomCrop transform (see also crop()) crops an image at a random location.

cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
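A random crop of fixed size only needs a random top-left corner, chosen so that the crop stays inside the image. The NumPy sketch below picks every valid corner with equal probability; it assumes no padding and the helper name random_crop is hypothetical.

```python
import numpy as np

def random_crop(img, crop_h, crop_w, rng):
    """Sketch: return a random (crop_h, crop_w) crop and its top-left corner."""
    h, w = img.shape[:2]
    # integers() excludes the upper bound, hence the + 1
    top = int(rng.integers(0, h - crop_h + 1))
    left = int(rng.integers(0, w - crop_w + 1))
    return img[top:top + crop_h, left:left + crop_w], (top, left)

rng = np.random.default_rng(0)
img = np.zeros((512, 512, 3), dtype=np.uint8)
crop, corner = random_crop(img, 128, 128, rng)
print(crop.shape, corner)
```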

RandomResizedCrop

The RandomResizedCrop transform (see also resized_crop()) crops a random portion of an image, with a random area and aspect ratio, and then resizes the crop to a given size.

resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
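The crop box is chosen before the resize. The sketch below is modeled on the documented defaults scale=(0.08, 1.0) (fraction of the image area) and ratio=(3/4, 4/3) (aspect ratio, sampled on a log scale): it retries a few times and falls back to a central crop. The retry count and fallback details are assumptions, and the helper name is hypothetical.

```python
import math
import random

def sample_resized_crop(height, width, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), rng=random):
    """Sketch: return (top, left, h, w) of a random crop box."""
    area = height * width
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(10):  # assumption: 10 attempts
        target_area = area * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(*log_ratio))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            return rng.randint(0, height - h), rng.randint(0, width - w), h, w
    # Fallback: central crop with the aspect ratio clamped to the allowed range
    in_ratio = width / height
    if in_ratio < min(ratio):
        w, h = width, int(round(width / min(ratio)))
    elif in_ratio > max(ratio):
        h, w = height, int(round(height * max(ratio)))
    else:
        h, w = height, width
    return (height - h) // 2, (width - w) // 2, h, w

print(sample_resized_crop(512, 512, rng=random.Random(0)))
```

The selected box would then be resized to the requested output size, (32, 32) in the example above.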

RandomInvert

The RandomInvert transform (see also invert()) randomly inverts the colors of the given image.

inverter = T.RandomInvert()
inverted_imgs = [inverter(orig_img) for _ in range(4)]
plot(inverted_imgs)

RandomPosterize

The RandomPosterize transform (see also posterize()) randomly posterizes the image by reducing the number of bits of each color channel.

posterizer = T.RandomPosterize(bits=2)
posterized_imgs = [posterizer(orig_img) for _ in range(4)]
plot(posterized_imgs)
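Posterizing to a given number of bits keeps only the most significant bits of each 8-bit value, so bits=2 leaves just four levels per channel. Here is a minimal NumPy sketch of that bit masking; the helper name posterize_np is hypothetical.

```python
import numpy as np

def posterize_np(img, bits):
    """Sketch: keep only the `bits` most significant bits of each uint8 value."""
    # e.g. bits=2 -> mask 0b11000000 (192)
    mask = (0xFF << (8 - bits)) & 0xFF
    return img & np.uint8(mask)

values = np.arange(256, dtype=np.uint8)
print(np.unique(posterize_np(values, 2)))
```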

RandomSolarize

The RandomSolarize transform (see also solarize()) randomly solarizes the image by inverting all pixel values at or above the threshold.

solarizer = T.RandomSolarize(threshold=192.0)
solarized_imgs = [solarizer(orig_img) for _ in range(4)]
plot(solarized_imgs)
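Solarizing is a per-pixel conditional inversion. The NumPy sketch below inverts values at or above the threshold and leaves the rest untouched; the helper name solarize_np is hypothetical.

```python
import numpy as np

def solarize_np(img, threshold):
    """Sketch: invert (255 - v) every uint8 value v >= threshold."""
    return np.where(img >= threshold, 255 - img, img).astype(np.uint8)

print(solarize_np(np.array([0, 191, 192, 255], dtype=np.uint8), 192.0))
```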

RandomAdjustSharpness

The RandomAdjustSharpness transform (see also adjust_sharpness()) randomly adjusts the sharpness of the given image.

sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2)
sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
plot(sharpened_imgs)
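Sharpness adjustment can be expressed as a blend between the image and a smoothed copy of it: a factor of 1 returns the original, and a factor above 1 (such as sharpness_factor=2 here) pushes the image away from the smoothed version. The sketch below assumes a 3x3 smoothing kernel with weight 5 in the center and 1 elsewhere (as in PIL's SMOOTH filter), keeps border pixels unchanged, and works on one channel; the helper name is hypothetical.

```python
import numpy as np

def adjust_sharpness_np(img, factor):
    """Sketch: blend a 2-D uint8 image with a smoothed copy (needs H, W >= 3)."""
    x = img.astype(np.float64)
    # Assumed smoothing kernel: 5 in the center, 1 around it, normalized by 13
    kernel = np.array([[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=np.float64) / 13.0
    h, w = x.shape
    acc = np.zeros((h - 2, w - 2))
    for dy in range(3):
        for dx in range(3):
            acc += kernel[dy, dx] * x[dy:dy + h - 2, dx:dx + w - 2]
    blurred = x.copy()
    blurred[1:-1, 1:-1] = acc  # only the interior is smoothed
    out = factor * x + (1.0 - factor) * blurred
    return np.clip(np.round(out), 0, 255).astype(np.uint8)

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(6, 7)).astype(np.uint8)
print(adjust_sharpness_np(img, 2.0))
```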

RandomAutocontrast

The RandomAutocontrast transform (see also autocontrast()) randomly applies autocontrast to the given image.

autocontraster = T.RandomAutocontrast()
autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
plot(autocontrasted_imgs)
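Autocontrast stretches each channel so that its darkest value maps to 0 and its brightest to 255. The NumPy sketch below does this per channel and leaves flat channels (minimum equal to maximum) unchanged; it illustrates the idea rather than reproducing torchvision's code, and the helper name is hypothetical.

```python
import numpy as np

def autocontrast_np(img):
    """Sketch: stretch each channel of an (H, W, C) uint8 array to span 0..255."""
    x = img.astype(np.float64)
    lo = x.min(axis=(0, 1), keepdims=True)
    hi = x.max(axis=(0, 1), keepdims=True)
    span = hi - lo
    flat = span == 0
    # Flat channels get scale 1 and offset 0, i.e. they stay as they are
    scale = np.where(flat, 1.0, 255.0 / np.where(flat, 1.0, span))
    offset = np.where(flat, 0.0, lo)
    out = (x - offset) * scale
    return np.clip(np.round(out), 0, 255).astype(np.uint8)

img = np.array([[[50], [100], [150]]], dtype=np.uint8)
print(autocontrast_np(img).ravel())
```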

RandomEqualize

The RandomEqualize transform (see also equalize()) randomly equalizes the histogram of the given image.

equalizer = T.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot(equalized_imgs)
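Histogram equalization remaps intensities through the cumulative histogram so that the output values spread across the full 0..255 range. The sketch below is the textbook CDF-based version for one channel; torchvision and PIL differ in the details, so treat it as an illustration of the technique, with a hypothetical helper name.

```python
import numpy as np

def equalize_np(channel):
    """Sketch: textbook histogram equalization of a 2-D uint8 array."""
    hist = np.bincount(channel.ravel(), minlength=256)
    cdf = np.cumsum(hist)
    cdf_min = cdf[np.nonzero(cdf)[0][0]]
    total = channel.size
    if total == cdf_min:  # single-valued image: nothing to spread out
        return channel.copy()
    lut = np.round((cdf - cdf_min) / (total - cdf_min) * 255)
    lut = np.clip(lut, 0, 255).astype(np.uint8)
    return lut[channel]

ch = np.array([[10, 10, 20, 30]], dtype=np.uint8)
print(equalize_np(ch))
```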

AutoAugment

The AutoAugment transform automatically augments data based on a given auto-augmentation policy. See AutoAugmentPolicy for the available policies.

policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)
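An AutoAugment policy is a list of sub-policies, each made of two (operation, probability, magnitude) steps. On each call, one sub-policy is picked at random and each of its steps fires with its own probability. The sketch below shows only that control flow: the operations act on a single number and the policy table is hypothetical, not one of the real CIFAR10, IMAGENET, or SVHN policies.

```python
import random

# Hypothetical operations on a single pixel value, just to show control flow
OPS = {
    "Invert": lambda x, m: 255 - x,
    "Brightness": lambda x, m: min(255, round(x * (1 + m))),
    "Identity": lambda x, m: x,
}

# Hypothetical policy: each sub-policy is two (op, probability, magnitude) steps
POLICY = [
    (("Invert", 0.1, None), ("Brightness", 0.8, 0.3)),
    (("Identity", 1.0, None), ("Invert", 0.5, None)),
]

def auto_augment(x, policy=POLICY, rng=random):
    sub_policy = rng.choice(policy)  # one sub-policy per call
    for name, prob, magnitude in sub_policy:
        if rng.random() < prob:  # each step fires independently
            x = OPS[name](x, magnitude)
    return x

print([auto_augment(100, rng=random.Random(i)) for i in range(4)])
```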

RandAugment

The RandAugment transform automatically augments the data by applying num_ops randomly chosen augmentations (2 by default), all at the same magnitude.

augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
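RandAugment drops the learned policy and keeps just two knobs: how many operations to apply and one shared magnitude. The sketch below assumes the defaults num_ops=2, magnitude=9 and num_magnitude_bins=31, and maps the magnitude bin to a fraction in [0, 1]; the operations on a single number are hypothetical, and details such as random sign flips of the magnitude are left out.

```python
import random

# Hypothetical ops on a single pixel value; m is the shared magnitude in [0, 1]
OPS = [
    lambda x, m: x,                             # identity
    lambda x, m: 255 - x,                       # invert (ignores m)
    lambda x, m: min(255, round(x * (1 + m))),  # brighten
    lambda x, m: max(0, round(x * (1 - m))),    # darken
]

def rand_augment(x, num_ops=2, magnitude=9, num_bins=31, rng=random):
    m = magnitude / (num_bins - 1)  # magnitude bin as a fraction of the maximum
    for _ in range(num_ops):
        op = rng.choice(OPS)  # drawn uniformly, with replacement
        x = op(x, m)
    return x

print([rand_augment(100, rng=random.Random(i)) for i in range(4)])
```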

TrivialAugmentWide

The TrivialAugmentWide transform automatically augments the data by applying a single randomly chosen augmentation with a randomly chosen magnitude.

augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

AugMix

The AugMix transform automatically augments the data by blending the original image with a weighted mix of several randomly augmented versions of it.

augmenter = T.AugMix()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
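The mixing step is what distinguishes AugMix: several augmentation chains run in parallel, their outputs are combined with Dirichlet-distributed weights, and the result is blended with the original using a Beta-distributed weight. In the sketch below, width=3 and alpha=1.0 follow AugMix's documented defaults, but the chain depth is fixed for simplicity (the default draws it at random), and the operations are illustrative.

```python
import numpy as np

def augmix_np(img, ops, rng, width=3, depth=3, alpha=1.0):
    """Sketch of AugMix on a uint8 array; `ops` map float arrays to float arrays."""
    x = img.astype(np.float64)
    weights = rng.dirichlet([alpha] * width)  # convex weights for the chains
    m = rng.beta(alpha, alpha)                # share of the original image
    mix = np.zeros_like(x)
    for w in weights:
        aug = x
        for _ in range(depth):
            aug = ops[rng.integers(len(ops))](aug)
        mix += w * aug
    out = m * x + (1.0 - m) * mix
    return np.clip(np.round(out), 0, 255).astype(np.uint8)

rng = np.random.default_rng(0)
img = (np.arange(12, dtype=np.uint8) * 20).reshape(2, 2, 3)
ops = [lambda a: a, lambda a: 255.0 - a]  # identity and invert (illustrative)
print(augmix_np(img, ops, rng))
```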

Randomly-applied transforms

Some transforms are randomly-applied given a probability p. That is, the transformed image may actually be the same as the original one, even when called with the same transformer instance!

RandomHorizontalFlip

The RandomHorizontalFlip transform (see also hflip()) flips an image horizontally with probability p.

hflipper = T.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
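A randomly-applied transform is a coin flip followed by a deterministic operation. The NumPy sketch below flips the width axis with probability p, so with p=0.5 roughly half of the calls return the original image unchanged; the helper name random_hflip is hypothetical.

```python
import numpy as np

def random_hflip(img, p, rng):
    """Sketch: flip an (H, W, ...) array left-right with probability p."""
    if rng.random() < p:
        return img[:, ::-1]  # reverse the width axis
    return img

rng = np.random.default_rng(0)
img = np.arange(6).reshape(2, 3)
print(random_hflip(img, 1.0, rng))
```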

RandomVerticalFlip

The RandomVerticalFlip transform (see also vflip()) flips an image vertically with probability p.

vflipper = T.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)

RandomApply

The RandomApply transform applies a list of transforms with probability p: on each call, either all of them are applied in order or none is.

applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)
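The all-or-nothing behavior can be sketched in a few lines of plain Python: one random draw decides whether the whole list runs. The helper name random_apply is hypothetical, and the transforms in the usage line are stand-ins that act on numbers.

```python
import random

def random_apply(x, transforms, p, rng=random):
    """Sketch: run every transform, in order, with probability p; otherwise none."""
    if rng.random() < p:
        for t in transforms:
            x = t(x)
    return x

steps = [lambda v: v + 1, lambda v: v * 10]  # stand-in transforms
print([random_apply(3, steps, 0.5, rng=random.Random(i)) for i in range(4)])
```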
