Illustration of transforms¶
This example illustrates the various transforms available in the torchvision.transforms module.
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)
def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
Pad¶
The Pad transform (see also pad()) fills image borders with some pixel values.
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)
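To make the effect of the padding argument concrete, here is a minimal plain-Python sketch of constant padding on a tiny 2D grid. The helper name `pad_constant` is hypothetical and is not part of torchvision; it only assumes that an integer padding is applied equally on all four sides with a constant fill value.

```python
def pad_constant(grid, padding, fill=0):
    """Pad a 2D list of pixel values by `padding` pixels on every side.

    Hypothetical helper for illustration only, not torchvision's code.
    """
    width = len(grid[0]) + 2 * padding
    blank_row = [fill] * width
    # Rows of pure fill above the image.
    padded = [list(blank_row) for _ in range(padding)]
    # Original rows with fill added on the left and right.
    for row in grid:
        padded.append([fill] * padding + list(row) + [fill] * padding)
    # Rows of pure fill below the image.
    padded.extend(list(blank_row) for _ in range(padding))
    return padded


tiny = [[1, 2],
        [3, 4]]
padded = pad_constant(tiny, padding=1)
for row in padded:
    print(row)
```

Each side grows by `padding` pixels, so a 2x2 grid padded by 1 becomes 4x4.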
Resize¶
The Resize transform (see also resize()) resizes an image.
# PIL's Image.size is (width, height), while Resize expects (height, width)
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size[::-1])]
plot(resized_imgs)
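The example above mixes integer sizes with a full size pair. As a rough sketch of how the two forms are interpreted, the hypothetical helper `resized_shape` below computes a target shape: an integer is matched to the shorter edge while the aspect ratio is kept, and a pair is taken as an explicit (height, width). The exact rounding used by the library is an assumption here and may differ.

```python
def resized_shape(height, width, size):
    """Return the (height, width) a Resize-like operation would aim for.

    Hypothetical helper for illustration; rounding is an assumption.
    """
    if isinstance(size, int):
        # Integer size: match the shorter edge and scale the longer
        # edge by the same ratio to preserve the aspect ratio.
        if height <= width:
            return size, int(size * width / height)
        return int(size * height / width), size
    # Sequence size: taken as an explicit (height, width) pair.
    new_height, new_width = size
    return new_height, new_width


print(resized_shape(400, 600, 100))       # landscape image
print(resized_shape(600, 400, 100))       # portrait image
print(resized_shape(400, 600, (30, 50)))  # explicit pair
```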
CenterCrop¶
The CenterCrop transform (see also center_crop()) crops the given image at the center.
# PIL's Image.size is (width, height), while CenterCrop expects (height, width)
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size[::-1])]
plot(center_crops)
FiveCrop¶
The FiveCrop transform (see also five_crop()) crops the given image into four corners and the central crop.
(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img)
plot([top_left, top_right, bottom_left, bottom_right, center])
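To show where the five crops come from, the sketch below computes their bounding boxes for a given image and crop size. The helper `five_crop_boxes` is hypothetical, the boxes use the (left, top, right, bottom) convention, and the floor division used to center the last box is an assumption that may differ from the library's rounding.

```python
def five_crop_boxes(width, height, crop_width, crop_height):
    """Return (left, top, right, bottom) boxes for the four corners and the center.

    Hypothetical helper for illustration only.
    """
    right_edge = width - crop_width     # left coordinate of right-hand crops
    bottom_edge = height - crop_height  # top coordinate of bottom crops
    center_left = right_edge // 2
    center_top = bottom_edge // 2
    return {
        "top_left": (0, 0, crop_width, crop_height),
        "top_right": (right_edge, 0, width, crop_height),
        "bottom_left": (0, bottom_edge, crop_width, height),
        "bottom_right": (right_edge, bottom_edge, width, height),
        "center": (center_left, center_top,
                   center_left + crop_width, center_top + crop_height),
    }


boxes = five_crop_boxes(width=300, height=200, crop_width=100, crop_height=100)
for name, box in boxes.items():
    print(name, box)
```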
Grayscale¶
The Grayscale transform (see also to_grayscale()) converts an image to grayscale.
gray_img = T.Grayscale()(orig_img)
plot([gray_img], cmap='gray')
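A grayscale conversion typically collapses the three color channels into one weighted sum. The sketch below uses the widely used ITU-R BT.601 luma weights; the helper name `to_gray` is hypothetical, and the exact weights and rounding used by the library are assumptions here.

```python
def to_gray(pixel):
    """Convert one (R, G, B) pixel with 0..255 channels to a gray level.

    Uses the common BT.601 luma weights; illustration only.
    """
    r, g, b = pixel
    return round(0.299 * r + 0.587 * g + 0.114 * b)


for name, pixel in [("red", (255, 0, 0)),
                    ("green", (0, 255, 0)),
                    ("blue", (0, 0, 255))]:
    print(name, to_gray(pixel))
```

Green contributes the most to the result and blue the least, which is why pure green maps to a much lighter gray than pure blue.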
Random transforms¶
The following transforms are random, which means that the same transformer instance will produce a different result each time it transforms a given image.
ColorJitter¶
The ColorJitter transform randomly changes the brightness, saturation, and other properties of an image.
jitter = T.ColorJitter(brightness=.5, hue=.3)
jitted_imgs = [jitter(orig_img) for _ in range(4)]
plot(jitted_imgs)
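The reason each call gives a different image is that a new factor is drawn every time. The sketch below illustrates this for brightness only, assuming that a setting of 0.5 means a factor drawn uniformly from [0.5, 1.5] and that brightness is applied by scaling and clamping each channel. Both helper names are hypothetical, and Python's `random` module stands in for the library's own random number generator.

```python
import random


def sample_brightness_factor(brightness, rng):
    """Draw a factor uniformly from [max(0, 1 - brightness), 1 + brightness]."""
    low = max(0.0, 1.0 - brightness)
    high = 1.0 + brightness
    return rng.uniform(low, high)


def adjust_brightness(pixel, factor):
    """Scale each channel by `factor` and clamp the result to 0..255."""
    return tuple(min(255, max(0, round(c * factor))) for c in pixel)


rng = random.Random(0)  # fixed seed so the sketch is repeatable
factors = [sample_brightness_factor(0.5, rng) for _ in range(4)]
for factor in factors:
    # A new factor per call is what makes each jittered image differ.
    print(round(factor, 3), adjust_brightness((100, 200, 250), factor))
```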
GaussianBlur¶
The GaussianBlur transform (see also gaussian_blur()) applies a Gaussian blur to an image.
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)
RandomPerspective¶
The RandomPerspective transform (see also perspective()) performs a random perspective transform on an image.
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
RandomRotation¶
The RandomRotation transform (see also rotate()) rotates an image by a random angle.
rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
RandomAffine¶
The RandomAffine transform (see also affine()) performs a random affine transform on an image.
affine_transformer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transformer(orig_img) for _ in range(4)]
plot(affine_imgs)
RandomCrop¶
The RandomCrop transform (see also crop()) crops an image at a random location.
cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
RandomResizedCrop¶
The RandomResizedCrop transform (see also resized_crop()) crops an image at a random location, and then resizes the crop to a given size.
resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
RandomInvert¶
The RandomInvert transform (see also invert()) randomly inverts the colors of the given image.
inverter = T.RandomInvert()
inverted_imgs = [inverter(orig_img) for _ in range(4)]
plot(inverted_imgs)
RandomPosterize¶
The RandomPosterize transform (see also posterize()) randomly posterizes the image by reducing the number of bits of each color channel.
posterizer = T.RandomPosterize(bits=2)
posterized_imgs = [posterizer(orig_img) for _ in range(4)]
plot(posterized_imgs)
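Reducing the number of bits can be pictured as masking off the low-order bits of each 8-bit channel value. The sketch below shows that arithmetic for single values; the helper name `posterize_value` is hypothetical, and 8-bit channels are assumed.

```python
def posterize_value(value, bits):
    """Keep only the `bits` most significant bits of an 8-bit value.

    Hypothetical helper for illustration only.
    """
    mask = (0xFF << (8 - bits)) & 0xFF  # e.g. bits=2 gives 0b11000000
    return value & mask


for value in (63, 183, 255):
    print(value, "->", posterize_value(value, bits=2))
```

With `bits=2`, every channel can only take the four values 0, 64, 128 and 192, which produces the flat, banded look of the images above.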
RandomSolarize¶
The RandomSolarize transform (see also solarize()) randomly solarizes the image by inverting all pixel values above the threshold.
solarizer = T.RandomSolarize(threshold=192.0)
solarized_imgs = [solarizer(orig_img) for _ in range(4)]
plot(solarized_imgs)
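Solarization is easy to state per pixel value: bright values are flipped and dark values are kept. The sketch below assumes 8-bit values and treats a value exactly equal to the threshold as inverted; that boundary choice and the helper name `solarize_value` are assumptions made for illustration.

```python
def solarize_value(value, threshold):
    """Invert an 8-bit value if it reaches the threshold, else keep it.

    Hypothetical helper; the handling of value == threshold is an assumption.
    """
    return 255 - value if value >= threshold else value


for value in (100, 192, 200, 255):
    print(value, "->", solarize_value(value, threshold=192))
```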
RandomAdjustSharpness¶
The RandomAdjustSharpness transform (see also adjust_sharpness()) randomly adjusts the sharpness of the given image.
sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2)
sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
plot(sharpened_imgs)
RandomAutocontrast¶
The RandomAutocontrast transform (see also autocontrast()) randomly applies autocontrast to the given image.
autocontraster = T.RandomAutocontrast()
autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
plot(autocontrasted_imgs)
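Autocontrast is commonly described as stretching a channel so that its darkest value maps to 0 and its brightest to 255. The sketch below shows that remapping on a flat list of values; the helper `autocontrast_channel` is hypothetical, and the integer rounding and the handling of a constant channel are assumptions.

```python
def autocontrast_channel(values):
    """Linearly stretch 8-bit values so min maps to 0 and max maps to 255.

    Hypothetical helper for illustration only.
    """
    lo, hi = min(values), max(values)
    if lo == hi:
        # A constant channel has no contrast to stretch; leave it as is.
        return list(values)
    return [(v - lo) * 255 // (hi - lo) for v in values]


print(autocontrast_channel([100, 120, 200]))
print(autocontrast_channel([7, 7, 7]))
```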
RandomEqualize¶
The RandomEqualize transform (see also equalize()) randomly equalizes the histogram of the given image.
equalizer = T.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot(equalized_imgs)
AutoAugment¶
The AutoAugment transform automatically augments data based on a given auto-augmentation policy. See AutoAugmentPolicy for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)
Randomly-applied transforms¶
Some transforms are applied randomly, with a given probability p. That is, the transformed image may actually be the same as the original one, even when called with the same transformer instance!
RandomHorizontalFlip¶
The RandomHorizontalFlip transform (see also hflip()) performs a horizontal flip of an image, with a given probability.
hflipper = T.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
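The probability-p behaviour described in this section can be sketched in a few lines of plain Python: draw a random number, and only apply the flip if it falls below p. The helpers `hflip` and `random_hflip` below operate on a tiny 2D list and are hypothetical stand-ins, with Python's `random` module used in place of the library's random number generator.

```python
import random


def hflip(grid):
    """Mirror a 2D list of pixel values left to right."""
    return [row[::-1] for row in grid]


def random_hflip(grid, p, rng):
    """Flip with probability `p`; otherwise return the input unchanged."""
    if rng.random() < p:
        return hflip(grid)
    return grid


tiny = [[1, 2, 3],
        [4, 5, 6]]
rng = random.Random(0)  # fixed seed so the sketch is repeatable
results = [random_hflip(tiny, p=0.5, rng=rng) for _ in range(20)]
flipped_count = sum(result == hflip(tiny) for result in results)
print(flipped_count, "of", len(results), "calls flipped the input")
```

With `p=0.0` the input is always returned as is, and with `p=1.0` it is always flipped; values in between give a mix of both across calls.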
RandomVerticalFlip¶
The RandomVerticalFlip transform (see also vflip()) performs a vertical flip of an image, with a given probability.
vflipper = T.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
RandomApply¶
The RandomApply transform randomly applies a list of transforms, with a given probability.
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)