Source code for torchvision.transforms.autoaugment
import math
from enum import Enum
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor

from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]


def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply the single augmentation operator named ``op_name`` to ``img``.

    Each branch returns the transformed image directly; unrecognized names
    raise ``ValueError``.
    """
    if op_name == "ShearX":
        # magnitude should be arctan(magnitude)
        # official autoaug: (1, level, 0, 0, 1, 0)
        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
        # compared to
        # torchvision: (1, tan(level), 0, 0, 1, 0)
        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
        return F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[math.degrees(math.atan(magnitude)), 0.0],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    if op_name == "ShearY":
        # magnitude should be arctan(magnitude) — see the ShearX note above.
        return F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[0.0, math.degrees(math.atan(magnitude))],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    if op_name == "TranslateX":
        return F.affine(
            img,
            angle=0.0,
            translate=[int(magnitude), 0],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    if op_name == "TranslateY":
        return F.affine(
            img,
            angle=0.0,
            translate=[0, int(magnitude)],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    if op_name == "Rotate":
        return F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    if op_name == "Brightness":
        return F.adjust_brightness(img, 1.0 + magnitude)
    if op_name == "Color":
        return F.adjust_saturation(img, 1.0 + magnitude)
    if op_name == "Contrast":
        return F.adjust_contrast(img, 1.0 + magnitude)
    if op_name == "Sharpness":
        return F.adjust_sharpness(img, 1.0 + magnitude)
    if op_name == "Posterize":
        return F.posterize(img, int(magnitude))
    if op_name == "Solarize":
        return F.solarize(img, magnitude)
    if op_name == "AutoContrast":
        return F.autocontrast(img)
    if op_name == "Equalize":
        return F.equalize(img)
    if op_name == "Invert":
        return F.invert(img)
    if op_name == "Identity":
        return img
    raise ValueError(f"The provided operator {op_name} is not recognized.")
class AutoAugmentPolicy(Enum):
    """Enumeration of the datasets on which AutoAugment policies were learned.

    Available policies are IMAGENET, CIFAR10 and SVHN.
    """

    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Fixed list of learned (op, probability, magnitude-bin) sub-policy pairs.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        # Each entry is a sub-policy of two (op_name, probability, magnitude_id) triples;
        # magnitude_id is None for ops that take no magnitude (e.g. Equalize, Invert).
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        # Maps op_name -> (magnitudes, signed). ``signed`` means the magnitude may be
        # negated with 50% probability; 0-dim tensors mark magnitude-free ops.
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            # Translations are scaled relative to a 331-px reference image (150/331).
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            # Posterize bit-depths run 8 down to 4 across the bins.
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Get parameters for autoaugment transformation

        Returns:
            params required by the autoaugment transformation
        """
        # policy_id: which sub-policy to apply; probs/signs: one draw per op in the pair.
        policy_id = int(torch.randint(transform_num, (1,)).item())
        probs = torch.rand((2,))
        signs = torch.randint(2, (2,))

        return policy_id, probs, signs

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        # Normalize fill to a per-channel float list for tensor inputs.
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        transform_id, probs, signs = self.get_params(len(self.policies))

        # AutoAugment always uses 10 magnitude bins.
        op_meta = self._augmentation_space(10, (height, width))
        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
            # Each op in the sub-policy fires independently with its learned probability.
            if probs[i] <= p:
                magnitudes, signed = op_meta[op_name]
                magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                # Signed magnitudes are negated half of the time.
                if signed and signs[i] == 0:
                    magnitude *= -1.0
                img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img
class RandAugment(torch.nn.Module):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        # op_name -> (magnitude bins, whether the magnitude may be negated).
        # 0-dim tensors flag ops that take no magnitude at all.
        return {
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            # Translation range scales with the image side (150/331 of it at max bin).
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            # Posterize bit depth decreases from 8 to 4 as the bin index grows.
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        channels, height, width = F.get_dimensions(img)
        fill = self.fill
        # Tensor inputs need the fill value expanded to a per-channel float list.
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
        op_names = list(op_meta.keys())
        # Apply ``num_ops`` uniformly-sampled ops in sequence, all at the same magnitude bin.
        for _ in range(self.num_ops):
            op_name = op_names[int(torch.randint(len(op_names), (1,)).item())]
            magnitudes, signed = op_meta[op_name]
            if magnitudes.ndim > 0:
                magnitude = float(magnitudes[self.magnitude].item())
            else:
                magnitude = 0.0
            # Signed ops flip direction with probability 1/2.
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation"
    <https://arxiv.org/abs/2103.10158>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        # op_name -> (magnitude bins, signed). The "wide" ranges are fixed and do not
        # depend on the image size; 0-dim tensors mark magnitude-free ops.
        return {
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            # Posterize bit depth runs from 8 down to 2 across the bins.
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        channels, height, width = F.get_dimensions(img)
        fill = self.fill
        # Tensor inputs need the fill value expanded to a per-channel float list.
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        # TrivialAugment applies exactly one op, with op, bin, and sign all uniform.
        op_meta = self._augmentation_space(self.num_magnitude_bins)
        op_name = list(op_meta.keys())[int(torch.randint(len(op_meta), (1,)).item())]
        magnitudes, signed = op_meta[op_name]
        if magnitudes.ndim > 0:
            bin_id = torch.randint(len(magnitudes), (1,), dtype=torch.long)
            magnitude = float(magnitudes[bin_id].item())
        else:
            magnitude = 0.0
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty"
    <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic
            depth sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Number of magnitude bins; severity indexes into the first ``severity`` bins.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        # op_name -> (magnitudes, signed); 0-dim tensors mark magnitude-free ops.
        s = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            # Max translation is one third of the corresponding image side.
            "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            # Posterize bit depth runs 4 down to 0 across the bins.
            "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        # Color-distorting ops are optional because they overlap with common eval corruptions.
        if self.all_ops:
            s.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return s

    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        # PIL conversion kept out of TorchScript via @torch.jit.unused.
        return F.pil_to_tensor(img)

    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(orig_img)
        # Work on a tensor throughout; convert PIL inputs in, and back out at the end.
        if isinstance(orig_img, Tensor):
            img = orig_img
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]
        else:
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))

        orig_dims = list(img.shape)
        # Promote to at least 4-D so the mixing weights broadcast over a batch dim.
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])

        # Start the mix with the original image weighted by its Beta draw.
        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # Negative chain_depth means a random depth in [1, 3] per chain.
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                op_index = int(torch.randint(len(op_meta), (1,)).item())
                op_name = list(op_meta.keys())[op_index]
                magnitudes, signed = op_meta[op_name]
                # Magnitude bin is drawn uniformly from the first ``severity`` bins.
                magnitude = (
                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
                    if magnitudes.ndim > 0
                    else 0.0
                )
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            # Accumulate this chain's output with its Dirichlet mixing weight.
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        # Restore the caller's shape and dtype (augmented mix is float internally).
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.