import torch
import math
import random
from PIL import Image
try:
import accimage
except ImportError:
accimage = None
import numpy as np
import numbers
import types
from collections.abc import Sequence, Iterable
import warnings
from . import functional as F
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage",
    "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda",
    "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
    "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
    "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
    "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale",
    "RandomGrayscale", "RandomPerspective", "RandomErasing",
]
# Maps PIL resampling-filter constants to readable names; the ``__repr__`` of
# transforms taking an ``interpolation`` argument looks values up here.
_pil_interpolation_to_str = {
    getattr(Image, _filter_name): 'PIL.Image.' + _filter_name
    for _filter_name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS', 'HAMMING', 'BOX')
}
def _get_image_size(img):
    """Return the spatial size of ``img`` as ``(width, height)``.

    Args:
        img (PIL Image or Tensor): A PIL image, or a tensor with more than two
            dimensions whose last two dimensions are taken as (H, W).

    Returns:
        ``img.size`` for a PIL image; for a tensor, its last two dimensions in
        reversed order (i.e. ``(W, H)``).

    Raises:
        TypeError: If ``img`` is neither a PIL image nor a tensor with more
            than two dimensions.
    """
    if F._is_pil_image(img):
        return img.size
    is_image_tensor = isinstance(img, torch.Tensor) and img.dim() > 2
    if not is_image_tensor:
        raise TypeError("Unexpected type {}".format(type(img)))
    # shape[-2:] is (H, W); reverse it to match PIL's (W, H) convention.
    return img.shape[-2:][::-1]
class Compose(object):
    """Chain several transforms so they run one after another.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order;
            each one receives the output of the previous one.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One line per wrapped transform, closed by a final ')' on its own line.
        body = ''.join('\n' + ' {0}'.format(t) for t in self.transforms)
        return self.__class__.__name__ + '(' + body + '\n)'
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a tensor.

    A PIL Image or numpy.ndarray laid out as (H x W x C) with values in [0, 255]
    becomes a torch.FloatTensor laid out as (C x H x W) with values in
    [0.0, 1.0], provided the PIL Image uses one of the modes
    (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or the numpy.ndarray has
    dtype = np.uint8. Any other input is converted without scaling.

    The actual conversion is delegated to ``F.to_tensor``.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return type(self).__name__ + '()'
class PILToTensor(object):
    """Convert a ``PIL Image`` to a tensor of the same type.

    A PIL Image laid out as (H x W x C) becomes a torch.Tensor laid out as
    (C x H x W). The conversion is delegated to ``F.pil_to_tensor``.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self):
        return type(self).__name__ + '()'
class ConvertImageDtype(object):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        self.dtype = dtype

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image (Tensor): Tensor image to convert.

        Returns:
            Tensor: ``image`` converted by ``F.convert_image_dtype`` to ``self.dtype``.
        """
        return F.convert_image_dtype(image, self.dtype)

    def __repr__(self) -> str:
        # Every other transform in this module defines __repr__; without one this
        # transform shows up as a bare object address inside Compose(...) reprs.
        return self.__class__.__name__ + '(dtype={0})'.format(self.dtype)
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range. The conversion
    is delegated to ``F.to_pil_image``.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        # The mode is only shown when one was explicitly given.
        mode_part = '' if self.mode is None else 'mode={0}'.format(self.mode)
        return self.__class__.__name__ + '(' + mode_part + ')'
class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        By default (``inplace=False``) this transform acts out of place, i.e., it
        does not mutate the input tensor. ``inplace=True`` is forwarded to
        ``F.normalize`` to request that the input tensor be modified in place.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place. Default: ``False``.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # Accept a plain int, or any iterable with exactly two entries.
        size_ok = isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        assert size_ok
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        # Raises KeyError for an interpolation value missing from the lookup table.
        readable_interp = _pil_interpolation_to_str[self.interpolation]
        details = '(size={0}, interpolation={1})'.format(self.size, readable_interp)
        return self.__class__.__name__ + details
class Scale(Resize):
    """Deprecated alias of ``Resize``.

    Note: This transform is deprecated in favor of Resize. Constructing it emits
    a warning and then initializes exactly like ``Resize``.
    """

    def __init__(self, *args, **kwargs):
        deprecation_message = ("The use of the transforms.Scale transform is deprecated, " +
                               "please use transforms.Resize instead.")
        warnings.warn(deprecation_message)
        super().__init__(*args, **kwargs)
class CenterCrop(object):
    """Crop the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. A number ``s``
            (converted with ``int``) produces a square ``(s, s)`` crop; a sequence
            like (h, w) is stored unchanged.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)
class Pad(object):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill
            - edge: pads with the last value at the edge of the image
            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]
            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Raises:
        ValueError: If ``padding`` is a tuple whose length is neither 2 nor 4.
    """

    def __init__(self, padding, fill=0, padding_mode='constant'):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        is_bad_length = isinstance(padding, Sequence) and len(padding) not in [2, 4]
        if is_bad_length:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))
        self.padding, self.fill, self.padding_mode = padding, fill, padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self):
        settings = '(padding={0}, fill={1}, padding_mode={2})'.format(
            self.padding, self.fill, self.padding_mode)
        return self.__class__.__name__ + settings
class Lambda(object):
    """Wrap a user-defined callable so it can be used as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform. It is called
            with the input and its return value is the transform's output.
    """

    def __init__(self, lambd):
        assert callable(lambd), "{!r} object is not callable".format(type(lambd).__name__)
        self.lambd = lambd

    def __call__(self, img):
        wrapped = self.lambd
        return wrapped(img)

    def __repr__(self):
        return type(self).__name__ + '()'
class RandomTransforms(object):
    """Base class for a list of transformations applied with some randomness.

    Subclasses implement ``__call__``; calling this base class directly raises
    ``NotImplementedError``.

    Args:
        transforms (list or tuple): list of transformations
    """

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self):
        text = self.__class__.__name__ + '('
        text += ''.join('\n' + ' {0}'.format(t) for t in self.transforms)
        return text + '\n)'
class RandomApply(RandomTransforms):
    """Apply randomly a list of transformations with a given probability.

    Either every transform in the list is applied (in order) or none is.

    Args:
        transforms (list or tuple): list of transformations
        p (float): probability of applying the transformations. ``p=0`` never
            applies them; ``p=1`` always does.
    """

    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, img):
        # random.random() is uniform on [0.0, 1.0), so ``random() < p`` holds with
        # probability exactly p. The previous test (skip only when p < random())
        # also applied the transforms when random() == p, e.g. for p=0 with a
        # draw of 0.0.
        if random.random() >= self.p:
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += '\n p={}'.format(self.p)
        for t in self.transforms:
            format_string += '\n'
            format_string += ' {0}'.format(t)
        format_string += '\n)'
        return format_string
class RandomOrder(RandomTransforms):
    """Apply every transformation in the list, in a freshly shuffled order per call."""

    def __call__(self, img):
        shuffled = list(range(len(self.transforms)))
        random.shuffle(shuffled)
        result = img
        for transform in (self.transforms[k] for k in shuffled):
            result = transform(result)
        return result
class RandomChoice(RandomTransforms):
    """Apply a single transformation picked with ``random.choice`` from the list."""

    def __call__(self, img):
        return random.choice(self.transforms)(img)
class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively. If a sequence of length 2 is provided, it is used to
            pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill
            - edge: pads with the last value on the edge of the image
            - reflect: pads with reflection of image (without repeating the last value on the edge).
              padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]
            - symmetric: pads with reflection of image (repeating the last value on the edge).
              padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size (th, tw) of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: If the requested crop is larger than the image.
        """
        w, h = _get_image_size(img)
        th, tw = output_size
        if h < th or w < tw:
            # random.randint would otherwise fail with an opaque "empty range" error.
            raise ValueError("Required crop size {} is larger than input image size {}"
                             .format((th, tw), (h, w)))
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        # Query the size via _get_image_size (as get_params does) instead of
        # indexing ``img.size``, which only works for PIL images: on a tensor
        # ``size`` is a method.
        width, _ = _get_image_size(img)
        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            img = F.pad(img, (self.size[1] - width, 0), self.fill, self.padding_mode)
        _, height = _get_image_size(img)
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            img = F.pad(img, (0, self.size[0] - height), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)
        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.

    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image (the input itself when not flipped).
        """
        should_flip = torch.rand(1) < self.p
        return F.hflip(img) if should_flip else img

    def __repr__(self):
        return type(self).__name__ + '(p={})'.format(self.p)
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.

    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image (the input itself when not flipped).
        """
        should_flip = torch.rand(1) < self.p
        return F.vflip(img) if should_flip else img

    def __repr__(self):
        return type(self).__name__ + '(p={})'.format(self.p)
class RandomPerspective(object):
    """Perform a perspective transformation of the given PIL Image randomly with a given probability.

    Args:
        distortion_scale (float): controls the degree of distortion and ranges from 0 to 1.
            Each corner may move inward by up to ``distortion_scale`` times half of the
            image width/height. Default value is 0.5.
        p (float): probability of the image being perspectively transformed. Default value is 0.5
        interpolation: resampling filter passed to ``F.perspective``. Default: ``Image.BICUBIC``
        fill (3-tuple or int): RGB pixel fill value for the area outside the transformed image.
            If int, it is used for all channels respectively. Default value is 0.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
        self.p = p
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be perspectively transformed.

        Returns:
            PIL Image: Randomly perspectively transformed image.

        Raises:
            TypeError: If ``img`` is not a PIL Image.
        """
        if not F._is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
        if not random.random() < self.p:
            return img
        width, height = img.size
        startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
        return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)

    @staticmethod
    def get_params(width, height, distortion_scale):
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width: width of the image.
            height: height of the image.
            distortion_scale: fraction of half the width/height each corner may move inward.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        max_dx = int(distortion_scale * int(width / 2))
        max_dy = int(distortion_scale * int(height / 2))

        def left_x():
            return random.randint(0, max_dx)

        def right_x():
            return random.randint(width - max_dx - 1, width - 1)

        def top_y():
            return random.randint(0, max_dy)

        def bottom_y():
            return random.randint(height - max_dy - 1, height - 1)

        # Corners are sampled x-then-y, in TL, TR, BR, BL order.
        topleft = (left_x(), top_y())
        topright = (right_x(), top_y())
        botright = (right_x(), bottom_y())
        botleft = (left_x(), bottom_y())

        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
        return startpoints, [topleft, topright, botright, botleft]

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge. A tuple or list is stored as-is;
            any other value ``s`` becomes ``(s, s)``.
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        self.size = size if isinstance(size, (tuple, list)) else (size, size)
        if scale[0] > scale[1] or ratio[0] > ratio[1]:
            # Only a warning: reversed ranges are still stored and used as given.
            warnings.warn("range should be of kind (min, max)")
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        width, height = _get_image_size(img)
        area = height * width

        # Up to 10 attempts at a random crop that fits inside the image.
        for _ in range(10):
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            # The aspect ratio is sampled log-uniformly within ``ratio``.
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < crop_w <= width and 0 < crop_h <= height:
                top = random.randint(0, height - crop_h)
                left = random.randint(0, width - crop_w)
                return top, left, crop_h, crop_w

        # Fallback: central crop, with the aspect ratio clamped into ``ratio``.
        in_ratio = float(width) / float(height)
        smallest, largest = min(ratio), max(ratio)
        if in_ratio < smallest:
            crop_w, crop_h = width, int(round(width / smallest))
        elif in_ratio > largest:
            crop_w, crop_h = int(round(height * largest)), height
        else:  # whole image
            crop_w, crop_h = width, height
        return (height - crop_h) // 2, (width - crop_w) // 2, crop_h, crop_w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        pieces = [
            self.__class__.__name__ + '(size={0}'.format(self.size),
            ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)),
            ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)),
            ', interpolation={0})'.format(interpolate_str),
        ]
        return ''.join(pieces)
class RandomSizedCrop(RandomResizedCrop):
    """Deprecated alias of ``RandomResizedCrop``.

    Note: This transform is deprecated in favor of RandomResizedCrop. Constructing
    it emits a warning and then initializes exactly like ``RandomResizedCrop``.
    """

    def __init__(self, *args, **kwargs):
        deprecation_message = ("The use of the transforms.RandomSizedCrop transform is deprecated, " +
                               "please use transforms.RandomResizedCrop instead.")
        warnings.warn(deprecation_message)
        super().__init__(*args, **kwargs)
class FiveCrop(object):
    """Crop the given PIL Image into four corners and the central crop.

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
        self.size = size

    def __call__(self, img):
        return F.five_crop(img, self.size)

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)
class TenCrop(object):
    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
        self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self):
        details = '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
        return self.__class__.__name__ + details
class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.

    Raises:
        ValueError: If a scalar argument is negative, or if the resulting range
            falls outside the allowed bounds (e.g. ``hue > 0.5``).
        TypeError: If an argument is neither a number nor a length-2 list/tuple.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalize one jitter argument into a ``[min, max]`` range, or ``None``.

        A scalar ``v`` becomes ``[center - v, center + v]`` (lower end clipped at 0
        when ``clip_first_on_zero``); a (min, max) pair is kept as given. Returns
        ``None`` when the range is the single point ``center`` (nothing to jitter).
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
            # Validate the expanded range against ``bound`` just like an explicit
            # (min, max) pair, so e.g. hue=0.7 (disallowed by the docstring) is
            # rejected here at construction time.
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    @torch.jit.unused
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast, saturation and
            hue (each only when its range is not None) in a random order.
        """
        transforms = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        # Each adjustment is applied at most once, in the order given by a
        # random permutation; adjustments whose range is None are skipped.
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and self.brightness is not None:
                brightness = self.brightness
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = F.adjust_brightness(img, brightness_factor)

            if fn_id == 1 and self.contrast is not None:
                contrast = self.contrast
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = F.adjust_contrast(img, contrast_factor)

            if fn_id == 2 and self.saturation is not None:
                saturation = self.saturation
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = F.adjust_saturation(img, saturation_factor)

            if fn_id == 3 and self.hue is not None:
                hue = self.hue
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string
class RandomRotation(object):
    """Rotate the image by a random angle.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.

    Raises:
        ValueError: If ``degrees`` is a negative number or a sequence whose length is not 2.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """

    def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            degrees = (-degrees, degrees)
        elif len(degrees) != 2:
            raise ValueError("If degrees is a sequence, it must be of len 2.")
        self.degrees = degrees
        self.resample = resample
        self.expand = expand
        self.center = center
        self.fill = fill

    @staticmethod
    def get_params(degrees):
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle drawn uniformly between ``degrees[0]`` and ``degrees[1]``,
            to be passed to ``rotate``.
        """
        return random.uniform(degrees[0], degrees[1])

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be rotated.

        Returns:
            PIL Image: Rotated image.
        """
        return F.rotate(img, self.get_params(self.degrees), self.resample,
                        self.expand, self.center, self.fill)

    def __repr__(self):
        pieces = [
            self.__class__.__name__ + '(degrees={0}'.format(self.degrees),
            ', resample={0}'.format(self.resample),
            ', expand={0}'.format(self.expand),
        ]
        if self.center is not None:
            pieces.append(', center={0}'.format(self.center))
        pieces.append(')')
        return ''.join(pieces)
class RandomAffine(object):
    """Random affine transformation of the image keeping center invariant.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or float or int, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
            outside the transform in the output image.(Pillow>=5.0.0)

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
                "degrees should be a list or tuple and it must be of length 2."
            self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = (-shear, shear)
            else:
                assert isinstance(shear, (tuple, list)) and \
                    (len(shear) == 2 or len(shear) == 4), \
                    "shear should be a list or tuple and it must be of length 2 or 4."
                if len(shear) == 2:
                    # X-axis shear with [min, max]; the y-axis range is padded with (0, 0).
                    self.shear = [shear[0], shear[1], 0., 0.]
                elif len(shear) == 4:
                    self.shear = list(shear)
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, img_size):
        """Get parameters for affine transformation.

        Args:
            degrees (sequence): (min, max) range the rotation angle is sampled from uniformly.
            translate (sequence or None): maximum absolute fractions (a, b) of the image width and
                height used to bound the horizontal / vertical shift; None means no translation.
            scale_ranges (sequence or None): (min, max) range for the scale factor; None means 1.0.
            shears (sequence or None): 2 values give an x-axis shear range (y shear is 0.);
                4 values give x-axis and y-axis shear ranges; None means a shear of 0.0.
            img_size (sequence): (width, height) of the image; only read when ``translate`` is set.

        Returns:
            sequence: params ``(angle, (dx, dy), scale, shear)`` to be passed to the affine
            transformation.
        """
        angle = random.uniform(degrees[0], degrees[1])
        if translate is not None:
            max_dx = translate[0] * img_size[0]
            max_dy = translate[1] * img_size[1]
            translations = (np.round(random.uniform(-max_dx, max_dx)),
                            np.round(random.uniform(-max_dy, max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = random.uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            if len(shears) == 2:
                shear = [random.uniform(shears[0], shears[1]), 0.]
            elif len(shears) == 4:
                shear = [random.uniform(shears[0], shears[1]),
                         random.uniform(shears[2], shears[3])]
        else:
            shear = 0.0

        return angle, translations, scale, shear

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be transformed.

        Returns:
            PIL Image: Affine transformed image.
        """
        # _get_image_size returns img.size (width, height) for PIL images, the last two dims
        # reversed for tensors with dim > 2, and raises TypeError otherwise; a bare ``img.size``
        # would be a bound method for tensors and break the translate computation.
        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, _get_image_size(img))
        return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)

    def __repr__(self):
        s = '{name}(degrees={degrees}'
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        # Truthiness instead of ``> 0`` so resample=None does not raise TypeError.
        if self.resample:
            s += ', resample={resample}'
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        d = dict(self.__dict__)
        if self.resample:
            # Only translate the filter name when it is printed, and fall back to the raw value
            # for anything missing from the table instead of raising KeyError.
            d['resample'] = _pil_interpolation_to_str.get(self.resample, self.resample)
        return s.format(name=self.__class__.__name__, **d)
class Grayscale(object):
    """Convert image to grayscale.

    The conversion is applied on every call (it is not random).

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: Grayscale version of the input.
        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
    """

    def __init__(self, num_output_channels=1):
        # Forwarded unchanged to F.to_grayscale on every call.
        self.num_output_channels = num_output_channels

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Grayscaled image with ``num_output_channels`` channels.
        """
        return F.to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self):
        return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
class RandomGrayscale(object):
    """Convert an image to grayscale at random.

    With probability ``p`` the image is passed through ``F.to_grayscale``;
    otherwise the very same input object is returned untouched.

    Args:
        p (float): probability of converting the image (default 0.1).

    Returns:
        PIL Image: grayscale version of the input with probability ``p``,
        the unchanged input with probability ``1 - p``.

        - Input of mode ``'L'``: the grayscale output has 1 channel.
        - Any other mode: the grayscale output has 3 channels with r == g == b.
    """

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image that may be converted to grayscale.

        Returns:
            PIL Image: Either the grayscale conversion or ``img`` itself.
        """
        # The channel count is derived from the input mode before the coin flip,
        # so ``img.mode`` is read whether or not the conversion happens.
        if img.mode == 'L':
            channels = 1
        else:
            channels = 3

        if not random.random() < self.p:
            return img
        return F.to_grayscale(img, num_output_channels=channels)

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)
class RandomErasing(object):
    """Randomly selects a rectangle region in an image and erases its pixels.

    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf

    Args:
        p: probability that the random erasing operation will be performed.
        scale: range of proportion of erased area against input image.
        ratio: range of aspect ratio of erased area.
        value: erasing value. Default is 0. If a single number, it is used to
            erase all pixels. If a tuple/list (e.g. of length 3), its entries are used to erase
            the channels (e.g. R, G, B) respectively.
            If a str such as 'random', each erased pixel is filled with values drawn from a
            standard normal distribution.
        inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.ToTensor(),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        assert isinstance(value, (numbers.Number, str, tuple, list))
        # Reversed ranges only warn; random.uniform still samples between the two bounds.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("range of scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("range of random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(img, scale, ratio, value=0):
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image of size (C, H, W) to be erased.
            scale: range of proportion of erased area against input image.
            ratio: range of aspect ratio of erased area.
            value: erasing value; a number, a str (random normal fill) or a tuple/list
                with one entry per channel.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
            If no fitting rectangle is found within 10 attempts, ``(0, 0, H, W, img)`` is
            returned so that erasing replaces the image with itself.

        Raises:
            TypeError: if ``value`` is not a number, str/bytes, tuple or list.
        """
        img_c, img_h, img_w = img.shape
        area = img_h * img_w

        for _ in range(10):
            erase_area = random.uniform(scale[0], scale[1]) * area
            aspect_ratio = random.uniform(ratio[0], ratio[1])

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))

            if h < img_h and w < img_w:
                i = random.randint(0, img_h - h)
                j = random.randint(0, img_w - w)
                if isinstance(value, numbers.Number):
                    v = value
                elif isinstance(value, (str, bytes)):
                    # (str, bytes) is what the private, now-removed torch._six.string_classes
                    # aliased on Python 3.
                    v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
                elif isinstance(value, (list, tuple)):
                    v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
                else:
                    raise TypeError("Unsupported erasing value type {}".format(type(value)))
                return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W) to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if random.uniform(0, 1) < self.p:
            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self):
        return '{0}(p={1}, scale={2}, ratio={3}, value={4}, inplace={5})'.format(
            self.__class__.__name__, self.p, self.scale, self.ratio, self.value, self.inplace)