# Source code for torchvision.transforms.v2._color
import collections.abc
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform

from ._transform import _RandomApplyTransform
from ._utils import query_chw
class Grayscale(Transform):
    """Convert images or videos to grayscale.

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 3 or 1, H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image
    """

    _v1_transform_cls = _transforms.Grayscale

    def __init__(self, num_output_channels: int = 1):
        super().__init__()
        self.num_output_channels = num_output_channels

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.rgb_to_grayscale`` on ``inpt`` with the configured channel count."""
        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)
class RandomGrayscale(_RandomApplyTransform):
    """Randomly convert image or videos to grayscale with a probability of p (default 0.1).

    If the input is a :class:`torch.Tensor`, it is expected to have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions

    The output has the same number of channels as the input.

    Args:
        p (float): probability that image should be converted to grayscale.
    """

    _v1_transform_cls = _transforms.RandomGrayscale

    def __init__(self, p: float = 0.1) -> None:
        # Forward ``p`` to the base class; otherwise the argument would be ignored.
        super().__init__(p=p)

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Record the channel count of the inputs so the output can keep it."""
        num_input_channels, *_ = query_chw(flat_inputs)
        return dict(num_input_channels=num_input_channels)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.rgb_to_grayscale``, requesting as many output channels as the input had."""
        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
class RGB(Transform):
    """Convert images or videos to RGB (if they are already not RGB).

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
    """

    def __init__(self):
        super().__init__()

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.grayscale_to_rgb`` on ``inpt``."""
        return self._call_kernel(F.grayscale_to_rgb, inpt)
class ColorJitter(Transform):
    """Randomly change the brightness, contrast, saturation and hue of an image or video.

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    _v1_transform_cls = _transforms.ColorJitter

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Falsy values (e.g. ``None`` for a disabled jitter) are replaced by 0
        # before being handed to the v1 class.
        return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}

    def __init__(
        self,
        brightness: Optional[Union[float, Sequence[float]]] = None,
        contrast: Optional[Union[float, Sequence[float]]] = None,
        saturation: Optional[Union[float, Sequence[float]]] = None,
        hue: Optional[Union[float, Sequence[float]]] = None,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        # Hue is centred on 0 and may be negative, so its lower end is not clipped at 0.
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _check_input(
        self,
        value: Optional[Union[float, Sequence[float]]],
        name: str,
        center: float = 1.0,
        bound: Tuple[float, float] = (0, float("inf")),
        clip_first_on_zero: bool = True,
    ) -> Optional[Tuple[float, float]]:
        """Normalize a jitter argument into a ``(min, max)`` range.

        Args:
            value: ``None``, a single non-negative number, or a sequence of two numbers.
            name: Argument name, used in error messages.
            center: Midpoint of the range built from a single number.
            bound: Inclusive ``(low, high)`` limits the resulting range must lie within.
            clip_first_on_zero: When a single number is given, clamp the lower end to 0.

        Returns:
            ``None`` if ``value`` is ``None`` or the range collapses to ``center``
            (i.e. the jitter is a no-op), otherwise ``(min, max)`` as floats.

        Raises:
            ValueError: If a single number is negative, or the range is unordered
                or falls outside ``bound``.
            TypeError: If ``value`` is neither a number nor a length-2 sequence.
        """
        if value is None:
            return None

        if isinstance(value, (int, float)):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
            value = [float(v) for v in value]
        else:
            raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

    @staticmethod
    def _generate_value(left: float, right: float) -> float:
        """Draw one value from ``uniform_(left, right)`` using torch's RNG."""
        return torch.empty(1).uniform_(left, right).item()

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the order of the four adjustments and a factor for each enabled one."""
        fn_idx = torch.randperm(4)

        b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
        c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
        s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
        h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])

        return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Apply the enabled adjustments to ``inpt`` in the order given by ``params["fn_idx"]``."""
        output = inpt
        brightness_factor = params["brightness_factor"]
        contrast_factor = params["contrast_factor"]
        saturation_factor = params["saturation_factor"]
        hue_factor = params["hue_factor"]
        # Index meaning: 0 = brightness, 1 = contrast, 2 = saturation, 3 = hue.
        # Adjustments whose factor is None are skipped.
        for fn_id in params["fn_idx"]:
            if fn_id == 0 and brightness_factor is not None:
                output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
        return output
class RandomChannelPermutation(Transform):
    """Randomly permute the channels of an image or video"""

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample a random permutation over the channel count of the inputs."""
        num_channels, *_ = query_chw(flat_inputs)
        return dict(permutation=torch.randperm(num_channels))

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.permute_channels`` on ``inpt`` with the sampled permutation."""
        # Passed by keyword, matching the RandomPhotometricDistort call to the same kernel.
        return self._call_kernel(F.permute_channels, inpt, permutation=params["permutation"])
class RandomPhotometricDistort(Transform):
    """Randomly distorts the image or video as used in `SSD: Single Shot
    MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.

    This transform relies on :class:`~torchvision.transforms.v2.ColorJitter`
    under the hood to adjust the contrast, saturation, hue, brightness, and also
    randomly permutes channels.

    Args:
        brightness (tuple of float (min, max), optional): How much to jitter brightness.
            brightness_factor is chosen uniformly from [min, max]. Should be non negative numbers.
        contrast (tuple of float (min, max), optional): How much to jitter contrast.
            contrast_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
        saturation (tuple of float (min, max), optional): How much to jitter saturation.
            saturation_factor is chosen uniformly from [min, max]. Should be non negative numbers.
        hue (tuple of float (min, max), optional): How much to jitter hue.
            hue_factor is chosen uniformly from [min, max]. Should have -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        p (float, optional) probability each distortion operation (contrast, saturation, ...) to be applied.
            Default is 0.5.
    """

    def __init__(
        self,
        brightness: Tuple[float, float] = (0.875, 1.125),
        contrast: Tuple[float, float] = (0.5, 1.5),
        saturation: Tuple[float, float] = (0.5, 1.5),
        hue: Tuple[float, float] = (-0.05, 0.05),
        p: float = 0.5,
    ) -> None:
        super().__init__()
        self.brightness = brightness
        self.contrast = contrast
        self.hue = hue
        self.saturation = saturation
        self.p = p

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample which distortions to apply and the factor for each.

        Each factor is drawn with probability ``p`` and is ``None`` otherwise.
        ``contrast_before`` decides whether contrast is applied before or after
        the saturation/hue adjustments, and ``channel_permutation`` is a random
        permutation (or ``None``) drawn with probability ``p``.
        """
        num_channels, *_ = query_chw(flat_inputs)
        # Keep this iteration order: it fixes the order of the RNG draws.
        params: Dict[str, Any] = {
            key: ColorJitter._generate_value(bounds[0], bounds[1]) if torch.rand(1) < self.p else None
            for key, bounds in [
                ("brightness_factor", self.brightness),
                ("contrast_factor", self.contrast),
                ("saturation_factor", self.saturation),
                ("hue_factor", self.hue),
            ]
        }
        params["contrast_before"] = bool(torch.rand(()) < 0.5)
        params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
        return params

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Apply the sampled distortions to ``inpt`` in a fixed order.

        Order: brightness, contrast (if ``contrast_before``), saturation, hue,
        contrast (if not ``contrast_before``), channel permutation. Steps whose
        parameter is ``None`` are skipped.
        """
        if params["brightness_factor"] is not None:
            inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
        if params["contrast_factor"] is not None and params["contrast_before"]:
            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
        if params["saturation_factor"] is not None:
            inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
        if params["hue_factor"] is not None:
            inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
        if params["contrast_factor"] is not None and not params["contrast_before"]:
            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
        if params["channel_permutation"] is not None:
            inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"])
        return inpt
class RandomEqualize(_RandomApplyTransform):
    """Equalize the histogram of the given image or video with a given probability.

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomEqualize

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.equalize`` on ``inpt``."""
        return self._call_kernel(F.equalize, inpt)
class RandomInvert(_RandomApplyTransform):
    """Inverts the colors of the given image or video with a given probability.

    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomInvert

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.invert`` on ``inpt``."""
        return self._call_kernel(F.invert, inpt)
class RandomPosterize(_RandomApplyTransform):
    """Posterize the image or video with a given probability by reducing the
    number of bits for each color channel.

    If the input is a :class:`torch.Tensor`, it should be of type torch.uint8,
    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomPosterize

    def __init__(self, bits: int, p: float = 0.5) -> None:
        # Forward ``p`` to the base class; otherwise the argument would be ignored.
        super().__init__(p=p)
        self.bits = bits

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.posterize`` on ``inpt`` with the configured number of bits."""
        return self._call_kernel(F.posterize, inpt, bits=self.bits)
class RandomSolarize(_RandomApplyTransform):
    """Solarize the image or video with a given probability by inverting all pixel
    values above a threshold.

    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomSolarize

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # The threshold is converted to a float before being handed to the v1 class.
        params = super()._extract_params_for_v1_transform()
        params["threshold"] = float(params["threshold"])
        return params

    def __init__(self, threshold: float, p: float = 0.5) -> None:
        # Forward ``p`` to the base class; otherwise the argument would be ignored.
        super().__init__(p=p)
        self.threshold = threshold

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.solarize`` on ``inpt`` with the configured threshold."""
        return self._call_kernel(F.solarize, inpt, threshold=self.threshold)
class RandomAutocontrast(_RandomApplyTransform):
    """Autocontrast the pixels of the given image or video with a given probability.

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomAutocontrast

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.autocontrast`` on ``inpt``."""
        return self._call_kernel(F.autocontrast, inpt)
class RandomAdjustSharpness(_RandomApplyTransform):
    """Adjust the sharpness of the image or video with a given probability.

    If the input is a :class:`torch.Tensor`,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomAdjustSharpness

    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
        # Forward ``p`` to the base class; otherwise the argument would be ignored.
        super().__init__(p=p)
        self.sharpness_factor = sharpness_factor

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Run ``F.adjust_sharpness`` on ``inpt`` with the configured factor."""
        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)