Shortcuts

Source code for torchvision.transforms.v2._container

from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch

from torch import nn
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import Transform


class Compose(Transform):
    """[BETA] Chain several transforms and apply them one after another.

    .. v2betastatus:: Compose transform

    This transform does not support torchscript. Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in the given order.

    Raises:
        TypeError: if ``transforms`` is not a sequence.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        is_sequence = isinstance(transforms, Sequence)
        if not is_sequence:
            raise TypeError("Argument transforms should be a sequence of callables")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Run every transform in order, feeding each one the previous result.

        A single positional argument is handed to the first transform as-is; several positional
        arguments are kept together and handed over as one tuple.
        """
        current = inputs[0] if len(inputs) <= 1 else inputs
        for step in self.transforms:
            current = step(current)
        return current

    def extra_repr(self) -> str:
        """Return one line per contained transform, each prefixed with a space."""
        return "\n".join(f" {t}" for t in self.transforms)
class RandomApply(Transform):
    """[BETA] Apply a list of transformations, as a whole, with a given probability.

    .. v2betastatus:: RandomApply transform

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of a
        list/tuple of transforms, as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.ModuleList): list of transformations
        p (float): probability of applying the list of transforms

    Raises:
        TypeError: if ``transforms`` is neither a sequence nor a ``nn.ModuleList``.
        ValueError: if ``p`` is outside ``[0.0, 1.0]``.
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        accepted_containers = (Sequence, nn.ModuleList)
        if not isinstance(transforms, accepted_containers):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        # Chained comparison on purpose: it is False for NaN, so NaN is rejected as well.
        p_in_range = 0.0 <= p <= 1.0
        if not p_in_range:
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return the constructor arguments as a dict with keys ``transforms`` and ``p``."""
        return dict(transforms=self.transforms, p=self.p)

    def forward(self, *inputs: Any) -> Any:
        """Either return the input untouched or run it through all transforms in order.

        A single positional argument is used as-is; several positional arguments are kept
        together as one tuple. One random number is drawn per call; when it is ``>= p`` the
        input is returned unchanged.
        """
        current = inputs[0] if len(inputs) <= 1 else inputs

        skip = torch.rand(1) >= self.p
        if skip:
            return current

        for step in self.transforms:
            current = step(current)
        return current

    def extra_repr(self) -> str:
        """Return one line per contained transform, each prefixed with a space."""
        return "\n".join(f" {t}" for t in self.transforms)
class RandomChoice(Transform):
    """[BETA] Apply a single transformation randomly picked from a list.

    .. v2betastatus:: RandomChoice transform

    This transform does not support torchscript.

    Args:
        transforms (sequence of callables): candidate transformations; exactly one is applied per call.
        p (list of floats or None, optional): weight of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.

    Raises:
        TypeError: if ``transforms`` is not a sequence.
        ValueError: if ``p`` has a different length than ``transforms``, contains a negative
            weight, or is non-empty but does not sum to a positive value.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")
        elif any(prob < 0 for prob in p):
            # Negative weights would only blow up later, inside torch.multinomial in forward().
            raise ValueError(f"Probabilities in p should be non-negative, but got {p}")

        total = sum(p)
        # `not total > 0` (rather than `total <= 0`) also catches a NaN sum. Without this check an
        # all-zero `p` fails below with a bare ZeroDivisionError that does not mention `p`.
        if len(p) > 0 and not total > 0:
            raise ValueError(f"Probabilities in p should sum to a positive value, but got {p}")

        super().__init__()

        self.transforms = transforms
        # Stored normalized so that the weights sum to 1.
        self.p = [prob / total for prob in p]

    def forward(self, *inputs: Any) -> Any:
        """Draw one transform according to ``self.p`` and call it with all positional inputs."""
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        transform = self.transforms[idx]
        return transform(*inputs)
class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. v2betastatus:: RandomOrder transform

    This transform does not support torchscript.

    Args:
        transforms (sequence of callables): list of transformations

    Raises:
        TypeError: if ``transforms`` is not a sequence.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        is_sequence = isinstance(transforms, Sequence)
        if not is_sequence:
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Apply every transform exactly once, in an order freshly shuffled on each call.

        A single positional argument is used as-is; several positional arguments are kept
        together as one tuple that is passed from transform to transform.
        """
        current = inputs[0] if len(inputs) <= 1 else inputs

        # One random permutation of the indices decides the order for this call.
        order = torch.randperm(len(self.transforms)).tolist()
        for position in order:
            current = self.transforms[position](current)
        return current

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources