Source code for torchvision.datapoints._image

from __future__ import annotations

from typing import Any, List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.transforms.functional import InterpolationMode

from ._datapoint import _FillTypeJIT, Datapoint

class Image(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for images.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the image. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the image. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the image is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the image. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.

    Raises:
        ValueError: If ``data`` converts to a tensor with fewer than 2 dimensions.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Image:
        """Reinterpret ``tensor`` as an instance of this class (no data copy via ``as_subclass``)."""
        image = tensor.as_subclass(cls)
        return image

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Image:
        if isinstance(data, PIL.Image.Image):
            # Imported lazily here, as in the original; presumably to avoid an import cycle with the v2 transforms.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            raise ValueError(
                f"Image data must have at least 2 dimensions (height, width), but got a tensor with {tensor.ndim}."
            )
        elif tensor.ndim == 2:
            # A bare 2D tensor is promoted to a single-channel image by adding a leading channel dimension,
            # so that ``num_channels`` (shape[-3]) is always defined.
            tensor = tensor.unsqueeze(0)
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
        """Wrap ``tensor`` as an ``Image``; ``other`` carries no extra metadata for images and is not used."""
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        # ``tensor_contents`` is accepted for signature compatibility but ignored.
        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
        """The last two dimensions of the tensor, ``(height, width)``."""
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    @property
    def num_channels(self) -> int:
        """The size of the third-to-last dimension."""
        return self.shape[-3]

    # Every method below follows the same pattern: unwrap ``self`` to a plain ``torch.Tensor``, call the matching
    # ``*_image_tensor`` kernel on ``self._F``, and re-wrap the result as an ``Image``.

    def horizontal_flip(self) -> Image:
        output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def vertical_flip(self) -> Image:
        output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resize_image_tensor(
            self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
        )
        return Image.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Image:
        output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
        return Image.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Image:
        output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
        return Image.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resized_crop_image_tensor(
            self.as_subclass(torch.Tensor),
            top,
            left,
            height,
            width,
            size=list(size),
            interpolation=interpolation,
            antialias=antialias,
        )
        return Image.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Image:
        output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
        return Image.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.rotate_image_tensor(
            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
        return Image.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.affine_image_tensor(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Image.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.perspective_image_tensor(
            self.as_subclass(torch.Tensor),
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
        return Image.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.elastic_image_tensor(
            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
        )
        return Image.wrap_like(self, output)

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
        output = self._F.rgb_to_grayscale_image_tensor(
            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
        )
        return Image.wrap_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Image:
        output = self._F.adjust_brightness_image_tensor(
            self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Image:
        output = self._F.adjust_saturation_image_tensor(
            self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
        )
        return Image.wrap_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Image:
        output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
        return Image.wrap_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Image:
        output = self._F.adjust_sharpness_image_tensor(
            self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Image:
        output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
        return Image.wrap_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
        output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
        return Image.wrap_like(self, output)

    def posterize(self, bits: int) -> Image:
        output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
        return Image.wrap_like(self, output)

    def solarize(self, threshold: float) -> Image:
        output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
        return Image.wrap_like(self, output)

    def autocontrast(self) -> Image:
        output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def equalize(self) -> Image:
        output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def invert(self) -> Image:
        output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
        output = self._F.gaussian_blur_image_tensor(
            self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
        )
        return Image.wrap_like(self, output)

    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
        output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
        return Image.wrap_like(self, output)
# Type aliases for image inputs. Each "JIT" variant collapses to a plain ``torch.Tensor``;
# presumably this is for TorchScript-scripted code paths, which cannot express these unions (TODO confirm).

# Any image input: plain tensor, PIL image, or Image datapoint.
_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
_ImageTypeJIT = torch.Tensor

# Tensor-backed image input only (no PIL).
_TensorImageType = Union[torch.Tensor, Image]
_TensorImageTypeJIT = torch.Tensor


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources