class RandomErasing(_RandomApplyTransform):
    """Randomly select a rectangle region in the input image or video and erase its pixels.

    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p (float, optional): probability that the random erasing operation will be performed. Default is 0.5.
        scale (tuple of float, optional): range of proportion of erased area against input image.
            Default is ``(0.02, 0.33)``.
        ratio (tuple of float, optional): range of aspect ratio of erased area. Default is ``(0.3, 3.3)``.
        value (number or sequence of numbers or str, optional): erasing value. Default is 0.
            If a single number, it is used to erase all pixels.
            If a sequence, it must hold either a single value or one value per channel of the input
            (e.g. a tuple of length 3 erases the R, G, B channels respectively).
            If the str ``'random'``, each erased pixel is filled with random values.
        inplace (bool, optional): boolean to make this transform inplace. Default set to False.

    Returns:
        Erased input.

    Raises:
        TypeError: if ``value`` is not a number, str or list/tuple, or cannot be converted to float,
            or if ``scale`` or ``ratio`` is not a sequence.
        ValueError: if ``value`` is a str other than ``'random'``, or if ``scale`` is not within ``[0, 1]``.

    Example:
        >>> from torchvision.transforms import v2 as transforms
        >>>
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    _v1_transform_cls = _transforms.RandomErasing

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # This class stores "random" as ``value=None``; the v1 transform expects the literal string back.
        return dict(
            super()._extract_params_for_v1_transform(),
            value="random" if self.value is None else self.value,
        )

    def __init__(
        self,
        p: float = 0.5,
        scale: Sequence[float] = (0.02, 0.33),
        ratio: Sequence[float] = (0.3, 3.3),
        value: Union[float, str, Sequence[float]] = 0.0,
        inplace: bool = False,
    ):
        super().__init__(p=p)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")

        self.scale = scale
        self.ratio = ratio

        # Normalize ``value`` to either None ("random") or a list of floats; make_params relies on len() of it.
        if isinstance(value, str):
            self.value = None
        elif isinstance(value, (list, tuple)):
            self.value = [float(v) for v in value]
        else:
            # Only numbers.Number instances reach this branch (validated above). Converting all of them, not
            # just builtin int/float, avoids storing e.g. a numpy integer or a Fraction as-is, which previously
            # made make_params fail on len(self.value).
            self.value = [float(value)]

        self.inplace = inplace
        # Aspect ratios are sampled uniformly in log space; precompute the log of the bounds once.
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Warn that bounding boxes and masks are currently passed through, then defer to the base-class dispatch.
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the rectangle to erase and the values to fill it with.

        Up to 10 attempts are made. Each draws an area fraction uniformly from ``self.scale`` and an aspect ratio
        log-uniformly from ``self.ratio``. A candidate is accepted only if it is strictly smaller than the image in
        both height and width. If every attempt fails, the returned box covers the whole image and ``v`` is None.

        Returns:
            dict with ``i``/``j`` (row/column offset of the rectangle), ``h``/``w`` (its size) and ``v``
            (fill values: a per-channel tensor of shape ``(len(value), 1, 1)``, standard-normal noise of shape
            ``(C, h, w)`` when ``value`` is ``'random'``, or None when no rectangle was found).

        Raises:
            ValueError: if ``value`` is a sequence whose length is neither 1 nor the number of input channels.
        """
        img_c, img_h, img_w = query_chw(flat_inputs)

        if self.value is not None and not (len(self.value) in (1, img_c)):
            raise ValueError(
                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
            )

        area = img_h * img_w

        log_ratio = self._log_ratio
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if self.value is None:
                # 'random': fill with standard-normal noise, one sample per erased pixel and channel.
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                # Shape (len(value), 1, 1) so the values broadcast over the erased rectangle.
                v = torch.tensor(self.value)[:, None, None]

            # randint's upper bound is exclusive, hence the +1 so the box may touch the bottom/right edge.
            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            break
        else:
            # No valid rectangle found in 10 attempts.
            i, j, h, w, v = 0, 0, img_h, img_w, None

        return dict(i=i, j=j, h=h, w=w, v=v)
class _BaseMixUpCutMix(Transform):
    """Shared input validation and label mixing for the MixUp and CutMix transforms.

    :meth:`forward` validates the batch and its labels, then delegates to ``self.make_params`` and
    ``self.transform`` (presumably implemented by the subclasses). It also marks the labels so that they are
    transformed too.
    """

    def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
        super().__init__()
        self.alpha = float(alpha)
        # Build the distribution from the float-normalized alpha. A raw int alpha would produce int64
        # concentration tensors, which Beta (via Dirichlet) cannot sample from.
        self._dist = torch.distributions.Beta(torch.tensor([self.alpha]), torch.tensor([self.alpha]))

        self.num_classes = num_classes

        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
        """Validate the batch and labels, then apply ``self.transform`` to the images/videos and the labels.

        Raises:
            ValueError: if the inputs contain PIL images, bounding boxes or masks; if the labels are not a tensor,
                are not 1D or 2D, are 2D with a last dimension different from ``num_classes``, are 1D while
                ``num_classes`` is None, or are not one of the inputs.
        """
        # A single positional argument is used as-is; several are kept together as one tuple.
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        if labels.ndim not in (1, 2):
            raise ValueError(
                f"labels should be index based with shape (batch_size,) "
                f"or probability based with shape (batch_size, num_classes), "
                f"but got a tensor of shape {labels.shape} instead."
            )
        if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
            raise ValueError(
                f"When passing 2D labels, "
                f"the number of elements in last dimension must match num_classes: "
                f"{labels.shape[-1]} != {self.num_classes}. "
                f"You can Leave num_classes to None."
            )
        if labels.ndim == 1 and self.num_classes is None:
            raise ValueError("num_classes must be passed if the labels are index-based (1D)")

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self.make_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in transform, so we make sure to set them to True.
        labels_idx = next((idx for idx, inpt in enumerate(flat_inputs) if inpt is labels), None)
        if labels_idx is None:
            # Previously a bare StopIteration escaped from next() here (and would be turned into a RuntimeError
            # if forward ran inside a generator), hiding the actual problem from the user.
            raise ValueError(
                f"{type(self).__name__}() could not find the labels returned by labels_getter among its inputs. "
                f"The labels must be one of the objects passed to the transform."
            )
        needs_transform_list[labels_idx] = True

        flat_outputs = [
            self.transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        """Raise ValueError unless ``inpt`` is batched (5 dims for videos, 4 otherwise) with ``batch_size`` items."""
        expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        """Return ``lam * label + (1 - lam) * label.roll(1, 0)`` as a floating-point tensor.

        1D (index) labels are one-hot encoded with ``self.num_classes`` first. ``roll(1, 0)`` pairs every sample
        with the previous one in the batch (the first sample with the last).
        """
        if label.ndim == 1:
            label = one_hot(label, num_classes=self.num_classes)  # type: ignore[arg-type]
        if not label.dtype.is_floating_point:
            label = label.float()
        # roll() returns a new tensor, so the in-place mul_/add_ never modify the caller's labels.
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
class MixUp(_BaseMixUpCutMix):
    """Apply MixUp to the provided batch of images and labels.

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)`` (class indices) or
    ``(batch_size, num_classes)`` (probabilities, e.g. already one-hot-encoded). They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded (2D). When given together with
            2D labels, it must match their last dimension.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """
class CutMix(_BaseMixUpCutMix):
    """Apply CutMix to the provided batch of images and labels.

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)`` (class indices) or
    ``(batch_size, num_classes)`` (probabilities, e.g. already one-hot-encoded). They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for cutmix. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded (2D). When given together with
            2D labels, it must match their last dimension.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """
class JPEG(Transform):
    """Apply JPEG compression and decompression to the given images.

    If the input is a :class:`torch.Tensor`, it is expected
    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        quality (int or sequence of int): JPEG quality, from 1 to 100. Lower means more compression.
            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
            randomly select from (inclusive of both ends).

    Returns:
        image with JPEG compression.

    Raises:
        ValueError: if a quality value is not an integer or is outside ``[1, 100]``, or if ``min > max``.
    """

    def __init__(self, quality: Union[int, Sequence[int]]):
        super().__init__()
        if isinstance(quality, int):
            # A single value is stored as the degenerate range (quality, quality).
            quality = [quality, quality]
        else:
            _check_sequence_input(quality, "quality", req_sizes=(2,))

        # Check the types first so that the range comparison only ever runs on integers. Otherwise,
        # non-numeric entries (e.g. strings or None) surface as an unrelated TypeError from ``<=``
        # instead of the ValueError below.
        if not (
            isinstance(quality[0], int)
            and isinstance(quality[1], int)
            and 1 <= quality[0] <= quality[1] <= 100
        ):
            raise ValueError(f"quality must be an integer from 1 to 100, got {quality=}")

        self.quality = quality
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.