class MultiLabelConfusionMatrix(Metric):
    """Calculates a confusion matrix for multi-labelled, multi-class data.

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - `y_pred` must contain 0s and 1s and has the following shape (batch_size, num_classes, ...).
      For example, `y_pred[i, j]` = 1 denotes that the j'th class is one of the labels of the i'th
      sample as predicted.
    - `y` should have the following shape (batch_size, num_classes, ...) with 0s and 1s. For example,
      `y[i, j]` = 1 denotes that the j'th class is one of the labels of the i'th sample according to
      the ground truth.
    - both `y` and `y_pred` must be torch Tensors having any of the following types:
      {torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}. They must have the same
      dimensions.
    - The confusion matrix 'M' is of dimension (num_classes, 2, 2).

      * M[i, 0, 0] corresponds to count/rate of true negatives of class i
      * M[i, 0, 1] corresponds to count/rate of false positives of class i
      * M[i, 1, 0] corresponds to count/rate of false negatives of class i
      * M[i, 1, 1] corresponds to count/rate of true positives of class i

    - The classes present in M are indexed as 0, ..., num_classes-1 as can be inferred from above.

    Args:
        num_classes: Number of classes, should be > 1.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a
            multi-output model and you want to compute the metric with respect to one of the
            outputs.
        device: specifies which device updates are accumulated on. Setting the metric's
            device to be the same as your ``update`` arguments ensures the ``update`` method
            is non-blocking. By default, CPU.
        normalized: whether to normalize confusion matrix by its sum or not.

    Example:

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MultiLabelConfusionMatrix(num_classes=3)
            metric.attach(default_evaluator, "mlcm")
            y_true = torch.tensor([
                [0, 0, 1],
                [0, 0, 0],
                [0, 0, 0],
                [1, 0, 0],
                [0, 1, 1],
            ])
            y_pred = torch.tensor([
                [1, 1, 0],
                [1, 0, 1],
                [1, 0, 0],
                [1, 0, 1],
                [1, 1, 0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["mlcm"])

        .. testoutput::

            tensor([[[0, 4],
                     [0, 1]],

                    [[3, 1],
                     [0, 1]],

                    [[1, 2],
                     [2, 0]]])

    .. versionadded:: 0.4.5
    """

    def __init__(
        self,
        num_classes: int,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        normalized: bool = False,
    ):
        # A 1-class "multi-label" matrix is degenerate; require at least 2 classes.
        if num_classes <= 1:
            raise ValueError("Argument num_classes needs to be > 1")

        self.num_classes = num_classes
        self._num_examples = 0
        self.normalized = normalized
        super(MultiLabelConfusionMatrix, self).__init__(output_transform=output_transform, device=device)

    @sync_all_reduce("confusion_matrix", "_num_examples")
    def compute(self) -> torch.Tensor:
        """Return the accumulated (num_classes, 2, 2) confusion matrix.

        Raises:
            NotComputableError: if ``update`` has never been called since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")

        if self.normalized:
            # NOTE(review): self.confusion_matrix is presumably allocated in reset()
            # (standard ignite Metric pattern) — not visible in this chunk; confirm.
            # Cast to float64 before dividing so integer counts become rates.
            conf = self.confusion_matrix.to(dtype=torch.float64)
            # Normalize each per-class 2x2 matrix by its own total count.
            sums = conf.sum(dim=(1, 2))
            return conf / sums[:, None, None]
        return self.confusion_matrix

    def _check_input(self, output: Sequence[torch.Tensor]) -> None:
        """Validate the ``(y_pred, y)`` pair: shapes, class count, integer dtype, binary values.

        Raises:
            ValueError: on any rank/shape/dtype mismatch, or if either tensor is not 0/1-valued.
        """
        y_pred, y = output[0].detach(), output[1].detach()

        if y_pred.ndimension() < 2:
            raise ValueError(
                f"y_pred must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
            )

        if y.ndimension() < 2:
            raise ValueError(
                f"y must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
            )

        if y_pred.shape[0] != y.shape[0]:
            raise ValueError(f"y_pred and y have different batch size: {y_pred.shape[0]} vs {y.shape[0]}")

        if y_pred.shape[1] != self.num_classes:
            raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[1]} vs {self.num_classes}")

        if y.shape[1] != self.num_classes:
            raise ValueError(f"y does not have correct number of classes: {y.shape[1]} vs {self.num_classes}")

        if y.shape != y_pred.shape:
            raise ValueError("y and y_pred shapes must match.")

        # Only integer dtypes are accepted; float inputs could hide non-binary values.
        valid_types = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
        if y_pred.dtype not in valid_types:
            raise ValueError(f"y_pred must be of any type: {valid_types}")

        if y.dtype not in valid_types:
            raise ValueError(f"y must be of any type: {valid_types}")

        # x == x**2 element-wise holds exactly for 0/1 integer entries.
        if not torch.equal(y_pred, y_pred ** 2):
            raise ValueError("y_pred must be a binary tensor")

        if not torch.equal(y, y ** 2):
            raise ValueError("y must be a binary tensor")