import warnings
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["SSIM"]


class SSIM(Metric):
    """Computes Structural Similarity Index Measure

    - ``update`` must receive output of the form ``(y_pred, y)``. They have to be of the same type. Valid
      :class:`torch.dtype` are the following:

      - on CPU: `torch.float32`, `torch.float64`.
      - on CUDA: `torch.float16`, `torch.bfloat16`, `torch.float32`, `torch.float64`.

    Args:
        data_range: Range of the image. Typically, ``1.0`` or ``255``.
        kernel_size: Size of the kernel. Default: 11
        sigma: Standard deviation of the gaussian kernel.
            Argument is used if ``gaussian=True``. Default: 1.5
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03
        gaussian: ``True`` to use gaussian kernel, ``False`` to use uniform kernel
        output_transform: A callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether output should be unrolled before being fed to the update method.
            Should be true for multi-output models, for example, if ``y_pred`` contains multi-output as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.
        ndims: Number of dimensions of the input image: 2d or 3d. Accepted values: 2, 3. Default: 2

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        ``y_pred`` and ``y`` can be un-normalized or normalized image tensors. Depending on that, the user might need
        to adjust ``data_range``. ``y_pred`` and ``y`` should have the same shape.

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = SSIM(data_range=1.0)
            metric.attach(default_evaluator, 'ssim')
            preds = torch.rand([4, 3, 16, 16])
            target = preds * 0.75
            state = default_evaluator.run([[preds, target]])
            print(state.metrics['ssim'])

        .. testoutput::

            0.9218971...

    .. versionadded:: 0.4.2
    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.
    .. versionchanged:: 0.5.2
        ``ndims`` argument is added.
    """

    _state_dict_all_req_keys = ("_sum_of_ssim", "_num_examples", "_kernel")

    def __init__(
        self,
        data_range: Union[int, float],
        kernel_size: Union[int, Sequence[int]] = 11,
        sigma: Union[float, Sequence[float]] = 1.5,
        k1: float = 0.01,
        k2: float = 0.03,
        gaussian: bool = True,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
        ndims: int = 2,
    ):
        if ndims not in (2, 3):
            raise ValueError(f"Expected ndims to be 2 or 3. Got {ndims}.")

        if isinstance(kernel_size, int):
            self.kernel_size: Sequence[int] = [kernel_size for _ in range(ndims)]
        elif isinstance(kernel_size, Sequence):
            if len(kernel_size) != ndims:
                raise ValueError(f"Expected kernel_size to have length of {ndims}. Got {len(kernel_size)}.")
            self.kernel_size = kernel_size
        else:
            raise ValueError(f"Argument kernel_size should be either int or a sequence of int of length {ndims}.")

        if isinstance(sigma, float):
            self.sigma: Sequence[float] = [sigma for _ in range(ndims)]
        elif isinstance(sigma, Sequence):
            if len(sigma) != ndims:
                raise ValueError(f"Expected sigma to have length of {ndims}. Got {len(sigma)}.")
            self.sigma = sigma
        else:
            raise ValueError(f"Argument sigma should be either float or a sequence of float of length {ndims}.")

        if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
            raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.")

        if any(y <= 0 for y in self.sigma):
            raise ValueError(f"Expected sigma to have positive number. Got {sigma}.")

        super(SSIM, self).__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)
        self.gaussian = gaussian
        self.data_range = data_range
        self.c1 = (k1 * data_range) ** 2
        self.c2 = (k2 * data_range) ** 2
        self.pad_h = (self.kernel_size[0] - 1) // 2
        self.pad_w = (self.kernel_size[1] - 1) // 2
        self.pad_d = None
        self.ndims = ndims
        if self.ndims == 3:
            self.pad_d = (self.kernel_size[2] - 1) // 2
        self._kernel_nd = self._gaussian_or_uniform_kernel(
            kernel_size=self.kernel_size, sigma=self.sigma, ndims=self.ndims
        )
        self._kernel: Optional[torch.Tensor] = None
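    # Illustrative constructor calls (a sketch; the kernel/sigma values below are
    # arbitrary examples, not defaults pulled from elsewhere):
    #
    #     SSIM(data_range=1.0)                              # 2d, 11x11 gaussian window
    #     SSIM(data_range=255, kernel_size=(5, 7, 7),
    #          sigma=(1.0, 1.5, 1.5), ndims=3)              # 3d, per-axis window sizes
    #
    # Even or non-positive kernel sizes, and sequences whose length does not match
    # ``ndims``, raise ValueError per the checks above.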
    def _uniform(self, kernel_size: int) -> torch.Tensor:
        kernel = torch.zeros(kernel_size, device=self._device)
        start_uniform_index = max(kernel_size // 2 - 2, 0)
        end_uniform_index = min(kernel_size // 2 + 3, kernel_size)
        min_, max_ = -2.5, 2.5
        kernel[start_uniform_index:end_uniform_index] = 1 / (max_ - min_)
        return kernel  # (kernel_size)

    def _gaussian(self, kernel_size: int, sigma: float) -> torch.Tensor:
        ksize_half = (kernel_size - 1) * 0.5
        kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=self._device)
        gauss = torch.exp(-0.5 * (kernel / sigma).pow(2))
        return gauss / gauss.sum()  # (kernel_size)

    def _gaussian_or_uniform_kernel(
        self, kernel_size: Sequence[int], sigma: Sequence[float], ndims: int
    ) -> torch.Tensor:
        if self.gaussian:
            kernel_x = self._gaussian(kernel_size[0], sigma[0])
            kernel_y = self._gaussian(kernel_size[1], sigma[1])
            if ndims == 3:
                kernel_z = self._gaussian(kernel_size[2], sigma[2])
        else:
            kernel_x = self._uniform(kernel_size[0])
            kernel_y = self._uniform(kernel_size[1])
            if ndims == 3:
                kernel_z = self._uniform(kernel_size[2])

        result = (
            torch.einsum("i,j->ij", kernel_x, kernel_y)
            if ndims == 2
            else torch.einsum("i,j,k->ijk", kernel_x, kernel_y, kernel_z)
        )
        return result

    def _check_type_and_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
        if y_pred.dtype != y.dtype:
            raise TypeError(
                f"Expected y_pred and y to have the same data type. Got y_pred: {y_pred.dtype} and y: {y.dtype}."
            )

        if y_pred.shape != y.shape:
            raise ValueError(
                f"Expected y_pred and y to have the same shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
            )

        # 2 dimensions are reserved for batch and channel
        if len(y_pred.shape) - 2 != self.ndims or len(y.shape) - 2 != self.ndims:
            raise ValueError(
                "Expected y_pred and y to have BxCxHxW or BxCxDxHxW shape. "
                f"Got y_pred: {y_pred.shape} and y: {y.shape}."
            )
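    # The ND window built above is separable: it is the outer product of normalized
    # 1d kernels. For example, in 2d (a sketch, assuming the default
    # kernel_size=11, sigma=1.5):
    #
    #     kx = self._gaussian(11, 1.5)            # (11,), sums to 1
    #     ky = self._gaussian(11, 1.5)            # (11,), sums to 1
    #     window = kx[:, None] * ky[None, :]      # (11, 11), same as einsum("i,j->ij", kx, ky)
    #
    # Since each 1d kernel is normalized, the window also sums to 1, so the
    # convolutions in ``update`` compute local weighted means.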
    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()
        self._check_type_and_shape(y_pred, y)

        # converts potential integer tensor to fp
        if not y.is_floating_point():
            y = y.float()
        if not y_pred.is_floating_point():
            y_pred = y_pred.float()

        nb_channel = y_pred.size(1)
        if self._kernel is None or self._kernel.shape[0] != nb_channel:
            self._kernel = self._kernel_nd.expand(nb_channel, 1, *[-1 for _ in range(self.ndims)])

        if y_pred.device != self._kernel.device:
            if self._kernel.device == torch.device("cpu"):
                self._kernel = self._kernel.to(device=y_pred.device)
            elif y_pred.device == torch.device("cpu"):
                warnings.warn(
                    "y_pred tensor is on cpu device but previous computation was on another device: "
                    f"{self._kernel.device}. To avoid having a performance hit, please ensure that all "
                    "y and y_pred tensors are on the same device.",
                )
                y_pred = y_pred.to(device=self._kernel.device)
                y = y.to(device=self._kernel.device)

        padding_shape = [self.pad_w, self.pad_w, self.pad_h, self.pad_h]
        if self.ndims == 3 and self.pad_d is not None:
            padding_shape.extend([self.pad_d, self.pad_d])

        y_pred = F.pad(y_pred, padding_shape, mode="reflect")
        y = F.pad(y, padding_shape, mode="reflect")

        if y_pred.dtype != self._kernel.dtype:
            self._kernel = self._kernel.to(dtype=y_pred.dtype)

        input_list = [y_pred, y, y_pred * y_pred, y * y, y_pred * y]
        conv_op = F.conv3d if self.ndims == 3 else F.conv2d
        outputs = conv_op(torch.cat(input_list), self._kernel, groups=nb_channel)
        batch_size = y_pred.size(0)
        output_list = [outputs[x * batch_size : (x + 1) * batch_size] for x in range(len(input_list))]

        mu_pred_sq = output_list[0].pow(2)
        mu_target_sq = output_list[1].pow(2)
        mu_pred_target = output_list[0] * output_list[1]

        sigma_pred_sq = output_list[2] - mu_pred_sq
        sigma_target_sq = output_list[3] - mu_target_sq
        sigma_pred_target = output_list[4] - mu_pred_target

        a1 = 2 * mu_pred_target + self.c1
        a2 = 2 * sigma_pred_target + self.c2
        b1 = mu_pred_sq + mu_target_sq + self.c1
        b2 = sigma_pred_sq + sigma_target_sq + self.c2

        ssim_idx = (a1 * a2) / (b1 * b2)

        # In case when ssim_idx can be MPS tensor and self._device is not MPS
        # self._double_dtype is float64.
        # As MPS does not support float64 we should set dtype to float32
        double_dtype = self._double_dtype
        if ssim_idx.device.type == "mps" and self._double_dtype == torch.float64:
            double_dtype = torch.float32

        # mean from all dimensions except batch
        self._sum_of_ssim += (
            torch.mean(ssim_idx, list(range(1, 2 + self.ndims)), dtype=double_dtype).sum().to(device=self._device)
        )

        self._num_examples += y.shape[0]
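    # ``update`` evaluates the standard SSIM map from the five convolved moments
    # (local means of y_pred, y, y_pred^2, y^2 and y_pred*y):
    #
    #     SSIM(x, y) = (2*mu_x*mu_y + c1) * (2*sigma_xy + c2)
    #                  --------------------------------------------------
    #                  (mu_x^2 + mu_y^2 + c1) * (sigma_x^2 + sigma_y^2 + c2)
    #
    # Stacking the five inputs into a single grouped convolution (the torch.cat
    # call above) is a performance choice: one conv2d/conv3d launch instead of five.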
    @sync_all_reduce("_sum_of_ssim", "_num_examples")
    def compute(self) -> float:
        if self._num_examples == 0:
            raise NotComputableError("SSIM must have at least one example before it can be computed.")
        return (self._sum_of_ssim / self._num_examples).item()
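
# A minimal sketch of using the metric directly, outside an Engine. It assumes a
# full ignite installation: the Metric base class calls ``reset()`` on construction,
# which initializes ``_sum_of_ssim``, ``_num_examples`` and ``_kernel`` (the reset
# method itself is not shown in this excerpt). Tensor values are illustrative only.
if __name__ == "__main__":
    metric = SSIM(data_range=1.0)
    preds = torch.rand(4, 3, 32, 32)
    target = preds * 0.9
    metric.update((preds, target))
    print(metric.compute())  # scalar SSIM averaged over the 4 batch examples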