Source code for torchvision.models.video.swin_transformer
# Modified from 2d Swin Transformers in torchvision:# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.pyfromfunctoolsimportpartialfromtypingimportAny,Callable,List,Optional,Tupleimporttorchimporttorch.nn.functionalasFfromtorchimportnn,Tensorfrom...transforms._presetsimportVideoClassificationfrom...utilsimport_log_api_usage_oncefrom.._apiimportregister_model,Weights,WeightsEnumfrom.._metaimport_KINETICS400_CATEGORIESfrom.._utilsimport_ovewrite_named_param,handle_legacy_interfacefrom..swin_transformerimportPatchMerging,SwinTransformerBlock__all__=["SwinTransformer3d","Swin3D_T_Weights","Swin3D_S_Weights","Swin3D_B_Weights","swin3d_t","swin3d_s","swin3d_b",]def_get_window_and_shift_size(shift_size:List[int],size_dhw:List[int],window_size:List[int])->Tuple[List[int],List[int]]:foriinrange(3):ifsize_dhw[i]<=window_size[i]:# In this case, window_size will adapt to the input size, and no need to shiftwindow_size[i]=size_dhw[i]shift_size[i]=0returnwindow_size,shift_sizetorch.fx.wrap("_get_window_and_shift_size")def_get_relative_position_bias(relative_position_bias_table:torch.Tensor,relative_position_index:torch.Tensor,window_size:List[int])->Tensor:window_vol=window_size[0]*window_size[1]*window_size[2]# In 3d case we flatten the relative_position_biasrelative_position_bias=relative_position_bias_table[relative_position_index[:window_vol,:window_vol].flatten()# type: ignore[index]]relative_position_bias=relative_position_bias.view(window_vol,window_vol,-1)relative_position_bias=relative_position_bias.permute(2,0,1).contiguous().unsqueeze(0)returnrelative_position_biastorch.fx.wrap("_get_relative_position_bias")def_compute_pad_size_3d(size_dhw:Tuple[int,int,int],patch_size:Tuple[int,int,int])->Tuple[int,int,int]:pad_size=[(patch_size[i]-size_dhw[i]%patch_size[i])%patch_size[i]foriinrange(3)]returnpad_size[0],pad_size[1],pad_size[2]torch.fx.wrap("_compute_pad_size_3d")def_compute_attention_mask_3d(x:Tensor,size_dhw:Tuple[int,int,int],window_size:Tuple[int,int,int],shift_size:Tuple[int,int,int],)->Tensor:# generate attention maskattn_mask=x.new_zeros(*size_dhw)num_windows=(size_dhw[0]//window_size[0])*(size_dhw[1]//window_size[1])*(size_dhw[2]//window_size[2])slices=[((0,-window_size[i]),(-window_size[i],-shift_size[i]),(-shift_size[i],None),)foriinrange(3)]count=0fordinslices[0]:forhinslices[1]:forwinslices[2]:attn_mask[d[0]:d[1],h[0]:h[1],w[0]:w[1]]=countcount+=1# Partition window on attn_maskattn_mask=attn_mask.view(size_dhw[0]//window_size[0],window_size[0],size_dhw[1]//window_size[1],window_size[1],size_dhw[2]//window_size[2],window_size[2],)attn_mask=attn_mask.permute(0,2,4,1,3,5).reshape(num_windows,window_size[0]*window_size[1]*window_size[2])attn_mask=attn_mask.unsqueeze(1)-attn_mask.unsqueeze(2)attn_mask=attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))returnattn_masktorch.fx.wrap("_compute_attention_mask_3d")defshifted_window_attention_3d(input:Tensor,qkv_weight:Tensor,proj_weight:Tensor,relative_position_bias:Tensor,window_size:List[int],num_heads:int,shift_size:List[int],attention_dropout:float=0.0,dropout:float=0.0,qkv_bias:Optional[Tensor]=None,proj_bias:Optional[Tensor]=None,training:bool=True,)->Tensor:""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: input (Tensor[B, T, H, W, C]): The input tensor, 5-dimensions. qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. relative_position_bias (Tensor): The learned relative position bias added to attention. window_size (List[int]): 3-dimensions window size, T, H, W . num_heads (int): Number of attention heads. shift_size (List[int]): Shift size for shifted window attention (T, H, W). attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. training (bool, optional): Training flag used by the dropout parameters. Default: True. Returns: Tensor[B, T, H, W, C]: The output tensor after shifted window attention. """b,t,h,w,c=input.shape# pad feature maps to multiples of window sizepad_size=_compute_pad_size_3d((t,h,w),(window_size[0],window_size[1],window_size[2]))x=F.pad(input,(0,0,0,pad_size[2],0,pad_size[1],0,pad_size[0]))_,tp,hp,wp,_=x.shapepadded_size=(tp,hp,wp)# cyclic shiftifsum(shift_size)>0:x=torch.roll(x,shifts=(-shift_size[0],-shift_size[1],-shift_size[2]),dims=(1,2,3))# partition windowsnum_windows=((padded_size[0]//window_size[0])*(padded_size[1]//window_size[1])*(padded_size[2]//window_size[2]))x=x.view(b,padded_size[0]//window_size[0],window_size[0],padded_size[1]//window_size[1],window_size[1],padded_size[2]//window_size[2],window_size[2],c,)x=x.permute(0,1,3,5,2,4,6,7).reshape(b*num_windows,window_size[0]*window_size[1]*window_size[2],c)# B*nW, Wd*Wh*Ww, C# multi-head attentionqkv=F.linear(x,qkv_weight,qkv_bias)qkv=qkv.reshape(x.size(0),x.size(1),3,num_heads,c//num_heads).permute(2,0,3,1,4)q,k,v=qkv[0],qkv[1],qkv[2]q=q*(c//num_heads)**-0.5attn=q.matmul(k.transpose(-2,-1))# add relative position biasattn=attn+relative_position_biasifsum(shift_size)>0:# generate attention mask to handle shifted windows with varying sizeattn_mask=_compute_attention_mask_3d(x,(padded_size[0],padded_size[1],padded_size[2]),(window_size[0],window_size[1],window_size[2]),(shift_size[0],shift_size[1],shift_size[2]),)attn=attn.view(x.size(0)//num_windows,num_windows,num_heads,x.size(1),x.size(1))attn=attn+attn_mask.unsqueeze(1).unsqueeze(0)attn=attn.view(-1,num_heads,x.size(1),x.size(1))attn=F.softmax(attn,dim=-1)attn=F.dropout(attn,p=attention_dropout,training=training)x=attn.matmul(v).transpose(1,2).reshape(x.size(0),x.size(1),c)x=F.linear(x,proj_weight,proj_bias)x=F.dropout(x,p=dropout,training=training)# reverse windowsx=x.view(b,padded_size[0]//window_size[0],padded_size[1]//window_size[1],padded_size[2]//window_size[2],window_size[0],window_size[1],window_size[2],c,)x=x.permute(0,1,4,2,5,3,6,7).reshape(b,tp,hp,wp,c)# reverse cyclic shiftifsum(shift_size)>0:x=torch.roll(x,shifts=(shift_size[0],shift_size[1],shift_size[2]),dims=(1,2,3))# unpad featuresx=x[:,:t,:h,:w,:].contiguous()returnxtorch.fx.wrap("shifted_window_attention_3d")classShiftedWindowAttention3d(nn.Module):""" See :func:`shifted_window_attention_3d`. """def__init__(self,dim:int,window_size:List[int],shift_size:List[int],num_heads:int,qkv_bias:bool=True,proj_bias:bool=True,attention_dropout:float=0.0,dropout:float=0.0,)->None:super().__init__()iflen(window_size)!=3orlen(shift_size)!=3:raiseValueError("window_size and shift_size must be of length 2")self.window_size=window_size# Wd, Wh, Wwself.shift_size=shift_sizeself.num_heads=num_headsself.attention_dropout=attention_dropoutself.dropout=dropoutself.qkv=nn.Linear(dim,dim*3,bias=qkv_bias)self.proj=nn.Linear(dim,dim,bias=proj_bias)self.define_relative_position_bias_table()self.define_relative_position_index()defdefine_relative_position_bias_table(self)->None:# define a parameter table of relative position biasself.relative_position_bias_table=nn.Parameter(torch.zeros((2*self.window_size[0]-1)*(2*self.window_size[1]-1)*(2*self.window_size[2]-1),self.num_heads,))# 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nHnn.init.trunc_normal_(self.relative_position_bias_table,std=0.02)defdefine_relative_position_index(self)->None:# get pair-wise relative position index for each token inside the windowcoords_dhw=[torch.arange(self.window_size[i])foriinrange(3)]coords=torch.stack(torch.meshgrid(coords_dhw[0],coords_dhw[1],coords_dhw[2],indexing="ij"))# 3, Wd, Wh, Wwcoords_flatten=torch.flatten(coords,1)# 3, Wd*Wh*Wwrelative_coords=coords_flatten[:,:,None]-coords_flatten[:,None,:]# 3, Wd*Wh*Ww, Wd*Wh*Wwrelative_coords=relative_coords.permute(1,2,0).contiguous()# Wd*Wh*Ww, Wd*Wh*Ww, 3relative_coords[:,:,0]+=self.window_size[0]-1# shift to start from 0relative_coords[:,:,1]+=self.window_size[1]-1relative_coords[:,:,2]+=self.window_size[2]-1relative_coords[:,:,0]*=(2*self.window_size[1]-1)*(2*self.window_size[2]-1)relative_coords[:,:,1]*=2*self.window_size[2]-1# We don't flatten the relative_position_index here in 3d case.relative_position_index=relative_coords.sum(-1)# Wd*Wh*Ww, Wd*Wh*Wwself.register_buffer("relative_position_index",relative_position_index)defget_relative_position_bias(self,window_size:List[int])->torch.Tensor:return_get_relative_position_bias(self.relative_position_bias_table,self.relative_position_index,window_size)# type: ignoredefforward(self,x:Tensor)->Tensor:_,t,h,w,_=x.shapesize_dhw=[t,h,w]window_size,shift_size=self.window_size.copy(),self.shift_size.copy()# Handle case where window_size is larger than the input tensorwindow_size,shift_size=_get_window_and_shift_size(shift_size,size_dhw,window_size)relative_position_bias=self.get_relative_position_bias(window_size)returnshifted_window_attention_3d(x,self.qkv.weight,self.proj.weight,relative_position_bias,window_size,self.num_heads,shift_size=shift_size,attention_dropout=self.attention_dropout,dropout=self.dropout,qkv_bias=self.qkv.bias,proj_bias=self.proj.bias,training=self.training,)# Modified from:# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.pyclassPatchEmbed3d(nn.Module):"""Video to Patch Embedding. Args: patch_size (List[int]): Patch token size. in_channels (int): Number of input channels. Default: 3 embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """def__init__(self,patch_size:List[int],in_channels:int=3,embed_dim:int=96,norm_layer:Optional[Callable[...,nn.Module]]=None,)->None:super().__init__()_log_api_usage_once(self)self.tuple_patch_size=(patch_size[0],patch_size[1],patch_size[2])self.proj=nn.Conv3d(in_channels,embed_dim,kernel_size=self.tuple_patch_size,stride=self.tuple_patch_size,)ifnorm_layerisnotNone:self.norm=norm_layer(embed_dim)else:self.norm=nn.Identity()defforward(self,x:Tensor)->Tensor:"""Forward function."""# padding_,_,t,h,w=x.size()pad_size=_compute_pad_size_3d((t,h,w),self.tuple_patch_size)x=F.pad(x,(0,pad_size[2],0,pad_size[1],0,pad_size[0]))x=self.proj(x)# B C T Wh Wwx=x.permute(0,2,3,4,1)# B T Wh Ww Cifself.normisnotNone:x=self.norm(x)returnxclassSwinTransformer3d(nn.Module):""" Implements 3D Swin Transformer from the `"Video Swin Transformer" <https://arxiv.org/abs/2106.13230>`_ paper. Args: patch_size (List[int]): Patch size. embed_dim (int): Patch embedding dimension. depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. window_size (List[int]): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. num_classes (int): Number of classes for classification head. Default: 400. norm_layer (nn.Module, optional): Normalization layer. Default: None. block (nn.Module, optional): SwinTransformer Block. Default: None. downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. patch_embed (nn.Module, optional): Patch Embedding layer. Default: None. """def__init__(self,patch_size:List[int],embed_dim:int,depths:List[int],num_heads:List[int],window_size:List[int],mlp_ratio:float=4.0,dropout:float=0.0,attention_dropout:float=0.0,stochastic_depth_prob:float=0.1,num_classes:int=400,norm_layer:Optional[Callable[...,nn.Module]]=None,block:Optional[Callable[...,nn.Module]]=None,downsample_layer:Callable[...,nn.Module]=PatchMerging,patch_embed:Optional[Callable[...,nn.Module]]=None,)->None:super().__init__()_log_api_usage_once(self)self.num_classes=num_classesifblockisNone:block=partial(SwinTransformerBlock,attn_layer=ShiftedWindowAttention3d)ifnorm_layerisNone:norm_layer=partial(nn.LayerNorm,eps=1e-5)ifpatch_embedisNone:patch_embed=PatchEmbed3d# split image into non-overlapping patchesself.patch_embed=patch_embed(patch_size=patch_size,embed_dim=embed_dim,norm_layer=norm_layer)self.pos_drop=nn.Dropout(p=dropout)layers:List[nn.Module]=[]total_stage_blocks=sum(depths)stage_block_id=0# build SwinTransformer blocksfori_stageinrange(len(depths)):stage:List[nn.Module]=[]dim=embed_dim*2**i_stagefori_layerinrange(depths[i_stage]):# adjust stochastic depth probability based on the depth of the stage blocksd_prob=stochastic_depth_prob*float(stage_block_id)/(total_stage_blocks-1)stage.append(block(dim,num_heads[i_stage],window_size=window_size,shift_size=[0ifi_layer%2==0elsew//2forwinwindow_size],mlp_ratio=mlp_ratio,dropout=dropout,attention_dropout=attention_dropout,stochastic_depth_prob=sd_prob,norm_layer=norm_layer,attn_layer=ShiftedWindowAttention3d,))stage_block_id+=1layers.append(nn.Sequential(*stage))# add patch merging layerifi_stage<(len(depths)-1):layers.append(downsample_layer(dim,norm_layer))self.features=nn.Sequential(*layers)self.num_features=embed_dim*2**(len(depths)-1)self.norm=norm_layer(self.num_features)self.avgpool=nn.AdaptiveAvgPool3d(1)self.head=nn.Linear(self.num_features,num_classes)forminself.modules():ifisinstance(m,nn.Linear):nn.init.trunc_normal_(m.weight,std=0.02)ifm.biasisnotNone:nn.init.zeros_(m.bias)defforward(self,x:Tensor)->Tensor:# x: B C T H Wx=self.patch_embed(x)# B _T _H _W Cx=self.pos_drop(x)x=self.features(x)# B _T _H _W Cx=self.norm(x)x=x.permute(0,4,1,2,3)# B, C, _T, _H, _Wx=self.avgpool(x)x=torch.flatten(x,1)x=self.head(x)returnxdef_swin_transformer3d(patch_size:List[int],embed_dim:int,depths:List[int],num_heads:List[int],window_size:List[int],stochastic_depth_prob:float,weights:Optional[WeightsEnum],progress:bool,**kwargs:Any,)->SwinTransformer3d:ifweightsisnotNone:_ovewrite_named_param(kwargs,"num_classes",len(weights.meta["categories"]))model=SwinTransformer3d(patch_size=patch_size,embed_dim=embed_dim,depths=depths,num_heads=num_heads,window_size=window_size,stochastic_depth_prob=stochastic_depth_prob,**kwargs,)ifweightsisnotNone:model.load_state_dict(weights.get_state_dict(progress=progress,check_hash=True))returnmodel_COMMON_META={"categories":_KINETICS400_CATEGORIES,"min_size":(1,1),"min_temporal_size":1,}
[docs]classSwin3D_T_Weights(WeightsEnum):KINETICS400_V1=Weights(url="https://download.pytorch.org/models/swin3d_t-7615ae03.pth",transforms=partial(VideoClassification,crop_size=(224,224),resize_size=(256,),mean=(0.4850,0.4560,0.4060),std=(0.2290,0.2240,0.2250),),meta={**_COMMON_META,"recipe":"https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400","_docs":("The weights were ported from the paper. The accuracies are estimated on video-level ""with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"),"num_params":28158070,"_metrics":{"Kinetics-400":{"acc@1":77.715,"acc@5":93.519,}},"_ops":43.882,"_file_size":121.543,},)DEFAULT=KINETICS400_V1
[docs]classSwin3D_S_Weights(WeightsEnum):KINETICS400_V1=Weights(url="https://download.pytorch.org/models/swin3d_s-da41c237.pth",transforms=partial(VideoClassification,crop_size=(224,224),resize_size=(256,),mean=(0.4850,0.4560,0.4060),std=(0.2290,0.2240,0.2250),),meta={**_COMMON_META,"recipe":"https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400","_docs":("The weights were ported from the paper. The accuracies are estimated on video-level ""with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"),"num_params":49816678,"_metrics":{"Kinetics-400":{"acc@1":79.521,"acc@5":94.158,}},"_ops":82.841,"_file_size":218.288,},)DEFAULT=KINETICS400_V1
[docs]classSwin3D_B_Weights(WeightsEnum):KINETICS400_V1=Weights(url="https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth",transforms=partial(VideoClassification,crop_size=(224,224),resize_size=(256,),mean=(0.4850,0.4560,0.4060),std=(0.2290,0.2240,0.2250),),meta={**_COMMON_META,"recipe":"https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400","_docs":("The weights were ported from the paper. The accuracies are estimated on video-level ""with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"),"num_params":88048984,"_metrics":{"Kinetics-400":{"acc@1":79.427,"acc@5":94.386,}},"_ops":140.667,"_file_size":364.134,},)KINETICS400_IMAGENET22K_V1=Weights(url="https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth",transforms=partial(VideoClassification,crop_size=(224,224),resize_size=(256,),mean=(0.4850,0.4560,0.4060),std=(0.2290,0.2240,0.2250),),meta={**_COMMON_META,"recipe":"https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400","_docs":("The weights were ported from the paper. The accuracies are estimated on video-level ""with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"),"num_params":88048984,"_metrics":{"Kinetics-400":{"acc@1":81.643,"acc@5":95.574,}},"_ops":140.667,"_file_size":364.134,},)DEFAULT=KINETICS400_V1
[docs]@register_model()@handle_legacy_interface(weights=("pretrained",Swin3D_T_Weights.KINETICS400_V1))defswin3d_t(*,weights:Optional[Swin3D_T_Weights]=None,progress:bool=True,**kwargs:Any)->SwinTransformer3d:""" Constructs a swin_tiny architecture from `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_. Args: weights (:class:`~torchvision.models.video.Swin3D_T_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.video.Swin3D_T_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_ for more details about this class. .. autoclass:: torchvision.models.video.Swin3D_T_Weights :members: """weights=Swin3D_T_Weights.verify(weights)return_swin_transformer3d(patch_size=[2,4,4],embed_dim=96,depths=[2,2,6,2],num_heads=[3,6,12,24],window_size=[8,7,7],stochastic_depth_prob=0.1,weights=weights,progress=progress,**kwargs,)
[docs]@register_model()@handle_legacy_interface(weights=("pretrained",Swin3D_S_Weights.KINETICS400_V1))defswin3d_s(*,weights:Optional[Swin3D_S_Weights]=None,progress:bool=True,**kwargs:Any)->SwinTransformer3d:""" Constructs a swin_small architecture from `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_. Args: weights (:class:`~torchvision.models.video.Swin3D_S_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.video.Swin3D_S_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_ for more details about this class. .. autoclass:: torchvision.models.video.Swin3D_S_Weights :members: """weights=Swin3D_S_Weights.verify(weights)return_swin_transformer3d(patch_size=[2,4,4],embed_dim=96,depths=[2,2,18,2],num_heads=[3,6,12,24],window_size=[8,7,7],stochastic_depth_prob=0.1,weights=weights,progress=progress,**kwargs,)
[docs]@register_model()@handle_legacy_interface(weights=("pretrained",Swin3D_B_Weights.KINETICS400_V1))defswin3d_b(*,weights:Optional[Swin3D_B_Weights]=None,progress:bool=True,**kwargs:Any)->SwinTransformer3d:""" Constructs a swin_base architecture from `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_. Args: weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.video.Swin3D_B_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_ for more details about this class. .. autoclass:: torchvision.models.video.Swin3D_B_Weights :members: """weights=Swin3D_B_Weights.verify(weights)return_swin_transformer3d(patch_size=[2,4,4],embed_dim=128,depths=[2,2,18,2],num_heads=[4,8,16,32],window_size=[8,7,7],stochastic_depth_prob=0.1,weights=weights,progress=progress,**kwargs,)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.