Source code for torchvision.transforms.v2._auto_augment
import math
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor

ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]


class _AutoAugmentBase(Transform):
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()

        if isinstance(params["fill"], dict):
            raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")

        return params

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    tv_tensors.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    tv_tensors.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec

        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        # Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
        image = cast(torch.Tensor, image)
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision: (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")

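# Illustrative note (not part of the module source): the "ShearX"/"ShearY" branches above
# convert the AutoAugment shear "level" into the degree-based shear angle that F.affine
# expects. For the maximum AutoAugment level of 0.3:
#
#     import math
#     math.degrees(math.atan(0.3))  # ~16.7 degrees passed as the shear argument
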
class AutoAugment(_AutoAugmentBase):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`.
            Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.AutoAugment

    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

class RandAugment(_AutoAugmentBase):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.RandAugment
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

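# Minimal usage sketch (illustrative, not part of the module source), assuming a uint8
# CHW image tensor; num_ops and magnitude below are simply the documented defaults:
#
#     import torch
#     from torchvision.transforms import v2
#
#     img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
#     transform = v2.RandAugment(num_ops=2, magnitude=9)
#     out = transform(img)  # two randomly chosen ops applied sequentially at magnitude 9
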
class TrivialAugmentWide(_AutoAugmentBase):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

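# Minimal usage sketch (illustrative, not part of the module source), assuming a uint8
# CHW image tensor. TrivialAugmentWide takes no strength parameter; each call picks one op
# and one magnitude bin at random:
#
#     import torch
#     from torchvision.transforms import v2
#
#     img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
#     transform = v2.TrivialAugmentWide(num_magnitude_bins=31)
#     out = transform(img)
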
class AugMix(_AutoAugmentBase):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty"
    <https://arxiv.org/abs/1912.02781>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A negative value denotes
            stochastic depth sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness).
            Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.AugMix

    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)  # type: ignore[arg-type]

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
                )  # type: ignore[assignment]
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)

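# Minimal usage sketch (illustrative, not part of the module source), assuming a uint8
# CHW image tensor. The output is a convex combination of the original image and
# `mixture_width` augmentation chains, with weights drawn from Dirichlet/Beta distributions
# as implemented in forward() above:
#
#     import torch
#     from torchvision.transforms import v2
#
#     img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
#     transform = v2.AugMix(severity=3, mixture_width=3, alpha=1.0)
#     out = transform(img)  # same shape and dtype as the input
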