# Source code for torchvision.transforms.transforms

import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List, Optional

import torch
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import functional as F
from .functional import InterpolationMode, _interpolation_modes_from_int

# Public API of this module: ``from <module> import *`` exports exactly these names.
# ``InterpolationMode`` is re-exported from ``.functional`` (imported above).
__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
]


class Compose:
    """Chain several transforms so that they run one after another.

    Each transform receives the output of the previous one. This class does not
    support torchscript; see the note below for a scriptable alternative.

    Args:
        transforms (list of ``Transform`` objects): the transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        To script a pipeline, wrap the transforms in ``torch.nn.Sequential`` instead:

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transforms may be used there, i.e. ones that operate on
        ``torch.Tensor`` and need neither `lambda` functions nor ``PIL.Image``.
    """

    def __init__(self, transforms):
        # API-usage logging is skipped while scripting or tracing.
        in_jit = torch.jit.is_scripting() or torch.jit.is_tracing()
        if not in_jit:
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self) -> str:
        # One line per wrapped transform, indented by four spaces.
        lines = [f"{self.__class__.__name__}("]
        lines.extend(f"    {transform}" for transform in self.transforms)
        lines.append(")")
        return "\n".join(lines)


class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Thin wrapper: all conversion logic lives in the functional API.
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.

    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.
    This transform does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.dtype = dtype

    def forward(self, image):
        # Delegates the conversion (and value rescaling) to the functional API.
        return F.convert_image_dtype(image, self.dtype)
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        # ``mode`` is only shown when it was explicitly given.
        format_string = self.__class__.__name__ + "("
        if self.mode is not None:
            format_string += f"mode={self.mode}"
        format_string += ")"
        return format_string
class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        By default (``inplace=False``) this transform acts out of place, i.e., it does not mutate the
        input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.
    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
            ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` only mode. This can help making the output for PIL images and tensors
            closer.

            .. warning::
                There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.max_size = max_size

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self) -> str:
        detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
        return f"{self.__class__.__name__}{detail}"
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")

        if not isinstance(fill, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate fill arg")

        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

        if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"


class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        _log_api_usage_once(self)
        if not callable(lambd):
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class RandomTransforms:
    """Base class for a list of transformations with randomness

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms):
        _log_api_usage_once(self)
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence")
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self) -> str:
        # Same layout as Compose.__repr__: one 4-space-indented line per transform.
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomApply(torch.nn.Module):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        # Skip all transforms with probability 1 - p; otherwise apply them all in order.
        if self.p < torch.rand(1):
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    p={self.p}"
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript."""

    def __call__(self, img):
        order = list(range(len(self.transforms)))
        random.shuffle(order)
        for i in order:
            img = self.transforms[i](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript."""

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        # ``p`` is used as relative weights; ``None`` means a uniform choice.
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self) -> str:
        return f"{super().__repr__()}(p={self.p})"


class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: if the crop is larger than the image along either edge.
        """
        _, h, w = F.get_dimensions(img)
        th, tw = output_size

        # The crop must fit entirely inside the image. A looser ``h + 1 < th`` check would let an
        # image one pixel too small through, and ``torch.randint(0, 0, ...)`` below would then
        # fail with an unrelated RuntimeError instead of this ValueError.
        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        # ``torch.randint``'s upper bound is exclusive, hence the ``+ 1``.
        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()
        return i, j, th, tw

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)

        self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        _, height, width = F.get_dimensions(img)
        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"


class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.hflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.vflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        # ``None`` is normalised to 0; anything other than a number or sequence is rejected.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """

        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # For tensors, turn ``fill`` into a list of floats; a scalar is repeated once per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        if torch.rand(1) < self.p:
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        half_height = height // 2
        half_width = width // 2
        # Each corner is moved inward by a random offset of at most distortion_scale * half the
        # image extent along each axis (torch.randint's upper bound is exclusive).
        topleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
        ]
        topright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
        ]
        botright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
        ]
        botleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
        ]
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of image and resize it to a given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
            ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=InterpolationMode.BILINEAR,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        _, height, width = F.get_dimensions(img)
        area = height * width

        # Aspect ratios are sampled uniformly in log space.
        log_ratio = torch.log(torch.tensor(ratio))
        # Up to 10 random attempts at a crop that fits inside the image.
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(size={self.size}"
        format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
        format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
        format_string += f", interpolation={interpolate_str})"
        return format_string
class FiveCrop(torch.nn.Module):
    """Crop the given image into its four corners plus a central crop.

    If the image is a torch Tensor, it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    .. Note::
         This transform returns a tuple of images, so the number of inputs your
         Dataset returns may no longer match the number of targets. See the
         example below for one way to deal with this.

    Args:
         size (sequence or int): Desired output size of each crop. An ``int`` gives a
            square crop of size (size, size); a sequence of length 1 is read as
            (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        size_error = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=size_error)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        crops = F.five_crop(img, self.size)
        return crops

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(size={self.size})"
class TenCrop(torch.nn.Module):
    """Crop the given image into its four corners and a central crop, then add the
    flipped version of each of those five crops (horizontal flip by default).

    If the image is a torch Tensor, it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    .. Note::
         This transform returns a tuple of images, so the number of inputs your
         Dataset returns may no longer match the number of targets. See the
         example below for one way to deal with this.

    Args:
        size (sequence or int): Desired output size of each crop. An int gives a
            square crop (size, size); a sequence of length 1 is read as
            (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal.

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        _log_api_usage_once(self)
        size_error = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=size_error)
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        crops = F.ten_crop(img, self.size, self.vertical_flip)
        return crops

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(size={self.size}, vertical_flip={self.vertical_flip})"
class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed offline.
    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        _log_api_usage_once(self)
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened. Its last three dimensions
                (C, H, W) are flattened into one vector of length D; any leading
                dimensions are treated as independent rows.

        Returns:
            Tensor: Transformed image with the same shape as the input. Its dtype is
            the result of promoting the input dtype with ``mean_vector``'s dtype.

        Raises:
            ValueError: if C * H * W does not equal D, or if the input is on a
                different device type than ``mean_vector``.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape. "
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        # ``reshape`` (rather than ``view``) also accepts non-contiguous inputs,
        # e.g. permuted or transposed tensors, copying only when it has to.
        flat_tensor = tensor.reshape(-1, n) - self.mean_vector

        # ``torch.mm`` needs both operands to share a dtype; align the matrix with the
        # (promoted) centred input so e.g. a float32 image works with a float64 matrix.
        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        return transformed_tensor.view(shape)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
        return s


class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image have to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        _log_api_usage_once(self)
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Validate one jitter argument and turn it into a ``[min, max]`` range.

        Args:
            value (number or tuple/list of 2 numbers): a non-negative spread around
                ``center``, or an explicit (min, max) pair.
            name (str): argument name used in error messages.
            center (float): neutral factor (``__init__`` passes 1 for
                brightness/contrast/saturation and 0 for hue).
            bound (tuple): inclusive (lower, upper) limits the final range must respect.
            clip_first_on_zero (bool): clamp the lower end of a spread-derived range at 0.

        Returns:
            The validated range: a new list for numeric input, or the caller's
            tuple/list as given. ``None`` if the range is exactly
            ``[center, center]``, which makes the adjustment a no-op.

        Raises:
            ValueError: if a single number is negative, or the range is not ordered
                and inside ``bound``.
            TypeError: if ``value`` is neither a number nor a length-2 tuple/list.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        # Validate the final range for BOTH input forms. For a single number this is
        # what rejects e.g. hue=0.8 ([-0.8, 0.8] lies outside (-0.5, 0.5)), which would
        # otherwise only fail later when the hue adjustment is applied.
        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Apply the enabled adjustments in the random order drawn by get_params.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s
class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        resample (int, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
                instead.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
    ):
        super().__init__()
        _log_api_usage_once(self)
        if resample is not None:
            warnings.warn(
                "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
                "Please use 'interpolation' instead."
            )
            interpolation = _interpolation_modes_from_int(resample)

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

        # ``resample`` is kept as an alias attribute for backward compatibility only;
        # ``interpolation`` is the documented attribute and is the one used in forward().
        self.resample = self.interpolation = interpolation
        self.expand = expand

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor inputs need one float fill value per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", expand={self.expand}"
        if self.center is not None:
            format_string += f", center={self.center}"
        if self.fill is not None:
            format_string += f", fill={self.fill}"
        format_string += ")"
        return format_string
class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        fillcolor (sequence or number, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``fill`` instead.
        resample (int, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
                instead.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        fillcolor=None,
        resample=None,
        center=None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        if resample is not None:
            warnings.warn(
                "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
                "Please use 'interpolation' instead."
            )
            interpolation = _interpolation_modes_from_int(resample)

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        if fillcolor is not None:
            warnings.warn(
                "The parameter 'fillcolor' is deprecated since 0.12 and will be removed in 0.14. "
                "Please use 'fill' instead."
            )
            fill = fillcolor

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        # ``resample`` / ``fillcolor`` are kept as alias attributes for backward compatibility.
        self.resample = self.interpolation = interpolation

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fillcolor = self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Args:
            degrees (list): (min, max) rotation angle range.
            translate (list, optional): maximum absolute (horizontal, vertical) translation fractions.
            scale_ranges (list, optional): (min, max) scale range; scale is 1.0 when None.
            shears (list, optional): 2 values (x-shear range) or 4 values (x- then y-shear ranges).
            img_size (list): image size as [width, height].

        Returns:
            params to be passed to the affine transformation:
            (angle, (tx, ty), scale, (shear_x, shear_y)).
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor inputs need one float fill value per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        s += f", translate={self.translate}" if self.translate is not None else ""
        s += f", scale={self.scale}" if self.scale is not None else ""
        s += f", shear={self.shear}" if self.shear is not None else ""
        s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
        s += f", fill={self.fill}" if self.fill != 0 else ""
        s += f", center={self.center}" if self.center is not None else ""
        s += ")"

        return s


class Grayscale(torch.nn.Module):
    """Convert image to grayscale.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image or Tensor: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        _log_api_usage_once(self)
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"


class RandomGrayscale(torch.nn.Module):
    """Randomly convert image to grayscale with a probability of p (default 0.1).
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        # Keep the input's channel count so the output is drop-in compatible.
        num_output_channels, _, _ = F.get_dimensions(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio of erased area.
         value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
         inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
            If no valid region is found in 10 attempts, the full image extent is
            returned with ``v = img``, which leaves the image unchanged.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [self.value]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, tuple):
                value = list(self.value)
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}"
            f"(p={self.p}, "
            f"scale={self.scale}, "
            f"ratio={self.ratio}, "
            f"value={self.value}, "
            f"inplace={self.inplace})"
        )
        return s


class GaussianBlur(torch.nn.Module):
    """Blurs image with randomly chosen Gaussian blur.
    If the image is torch Tensor, it is expected
    to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.

    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        _log_api_usage_once(self)
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])
        # The same sigma is used for both spatial axes.
        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
        return s


def _setup_size(size, error_msg):
    """Normalise ``size`` to a two-element (h, w) value.

    A number becomes ``(int(size), int(size))``. A length-1 sequence becomes
    ``(size[0], size[0])``. A length-2 object is returned unchanged. Any other
    length raises ``ValueError(error_msg)``.
    """
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    """Raise unless ``x`` is a ``Sequence`` whose length is one of ``req_sizes``.

    Raises:
        TypeError: if ``x`` is not a ``Sequence``.
        ValueError: if ``len(x)`` is not in ``req_sizes``.
    """
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
        raise TypeError(f"{name} should be a sequence of length {msg}.")
    if len(x) not in req_sizes:
        raise ValueError(f"{name} should be sequence of length {msg}.")


def _setup_angle(x, name, req_sizes=(2,)):
    """Turn an angle argument into a list of floats.

    A single non-negative number ``x`` becomes ``[-x, x]``. A sequence is
    validated against ``req_sizes`` and converted element-wise to float.
    """
    if isinstance(x, numbers.Number):
        if x < 0:
            raise ValueError(f"If {name} is a single number, it must be positive.")
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]


class RandomInvert(torch.nn.Module):
    """Inverts the colors of the given image randomly with a given probability.
    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        if torch.rand(1).item() < self.p:
            return F.invert(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
class RandomPosterize(torch.nn.Module):
    """With probability ``p``, posterize the image by keeping only the top ``bits``
    bits of every color channel.

    A torch Tensor input should be of type torch.uint8 with [..., 1 or 3, H, W]
    shape, where ... means an arbitrary number of leading dimensions. A PIL Image
    is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        should_apply = torch.rand(1).item() < self.p
        return F.posterize(img, self.bits) if should_apply else img

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(bits={self.bits},p={self.p})"
class RandomSolarize(torch.nn.Module):
    """Solarize an image with probability ``p``, inverting every pixel value at or above ``threshold``.

    Tensor inputs are expected in ``[..., 1 or 3, H, W]`` layout, where ``...`` denotes any
    number of leading dimensions. PIL Image inputs are expected to be in mode "L" or "RGB".

    Args:
        threshold (float): pixels whose value is equal to or greater than this get inverted.
        p (float): chance that the image gets solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p
        self.threshold = threshold

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: The solarized image when the random draw selects the
            transform, otherwise ``img`` unchanged.
        """
        # A single draw from torch.rand(1), uniform on [0, 1), decides whether to transform.
        apply_solarize = torch.rand(1).item() < self.p
        if not apply_solarize:
            return img
        return F.solarize(img, self.threshold)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(threshold={self.threshold},p={self.p})"
class RandomAdjustSharpness(torch.nn.Module):
    """Change the sharpness of an image with probability ``p``.

    Tensor inputs are expected in ``[..., 1 or 3, H, W]`` layout, where ``...`` denotes any
    number of leading dimensions.

    Args:
        sharpness_factor (float): Strength of the sharpness adjustment; any non-negative
            number. 0 produces a blurred image, 1 keeps the original image, and 2 increases
            the sharpness by a factor of 2.
        p (float): chance that the image gets sharpened. Default value is 0.5
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p
        self.sharpness_factor = sharpness_factor

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: The sharpness-adjusted image when the random draw selects
            the transform, otherwise ``img`` unchanged.
        """
        # A single draw from torch.rand(1), uniform on [0, 1), decides whether to transform.
        apply_sharpness = torch.rand(1).item() < self.p
        if not apply_sharpness:
            return img
        return F.adjust_sharpness(img, self.sharpness_factor)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(sharpness_factor={self.sharpness_factor},p={self.p})"


class RandomAutocontrast(torch.nn.Module):
    """Apply autocontrast to an image with probability ``p``.

    Tensor inputs are expected in ``[..., 1 or 3, H, W]`` layout, where ``...`` denotes any
    number of leading dimensions. PIL Image inputs are expected to be in mode "L" or "RGB".

    Args:
        p (float): chance that the image gets autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: The autocontrasted image when the random draw selects the
            transform, otherwise ``img`` unchanged.
        """
        # A single draw from torch.rand(1), uniform on [0, 1), decides whether to transform.
        apply_autocontrast = torch.rand(1).item() < self.p
        if not apply_autocontrast:
            return img
        return F.autocontrast(img)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(p={self.p})"


class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of an image with probability ``p``.

    Tensor inputs are expected in ``[..., 1 or 3, H, W]`` layout, where ``...`` denotes any
    number of leading dimensions. PIL Image inputs are expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): chance that the image gets equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: The equalized image when the random draw selects the
            transform, otherwise ``img`` unchanged.
        """
        # A single draw from torch.rand(1), uniform on [0, 1), decides whether to transform.
        apply_equalize = torch.rand(1).item() < self.p
        if not apply_equalize:
            return img
        return F.equalize(img)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(p={self.p})"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources