sigmoid_focal_loss
torchvision.ops.sigmoid_focal_loss(inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2, reduction: str = 'none') → torch.Tensor

Focal loss, as used in RetinaNet for dense object detection: https://arxiv.org/abs/1708.02002.
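For a single element with logit x (from inputs) and binary label y ∈ {0, 1} (from targets), the loss in the paper linked above can be written as follows. The symbols x, y, p, p_t and α_t are notation introduced for this sketch; alpha and gamma correspond to the parameters below. With gamma = 0 and the weighting disabled (alpha = -1, so α_t is effectively 1), this reduces to ordinary binary cross-entropy, -log(p_t); a larger gamma shrinks the loss of well-classified elements, where p_t is close to 1.

```latex
p = \sigma(x), \qquad
p_t = p\,y + (1 - p)(1 - y), \qquad
\alpha_t = \alpha\,y + (1 - \alpha)(1 - y)

\mathrm{FL}(p_t) = -\,\alpha_t \,(1 - p_t)^{\gamma} \,\log(p_t)
```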
Parameters
inputs (Tensor) – A float tensor of arbitrary shape. The predictions for each example, given as raw logits (the sigmoid is applied inside the function).
targets (Tensor) – A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class).
alpha (float) – Weighting factor in the range (0, 1) to balance positive vs. negative examples, or -1 to disable this weighting. Default: 0.25.
gamma (float) – Exponent of the modulating factor (1 - p_t) to balance easy vs. hard examples. Default: 2.
reduction (string) – 'none' | 'mean' | 'sum'. 'none': no reduction will be applied to the output. 'mean': the output will be averaged. 'sum': the output will be summed. Default: 'none'.
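The parameter descriptions above can be checked against a small plain-Python sketch of the per-element computation. It is not torchvision's source: the names focal_loss_element and focal_loss are hypothetical, the sketch works on flat lists of floats instead of tensors, and it assumes, as described above, that inputs are logits and that a negative alpha disables the class weighting.

```python
import math
from typing import List, Sequence, Union


def _sigmoid(x: float) -> float:
    # Numerically stable logistic sigmoid.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def focal_loss_element(logit: float, target: float,
                       alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one element (hypothetical helper, not a torchvision API)."""
    p = _sigmoid(logit)
    # Binary cross-entropy computed directly from the logit (stable form).
    ce = max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))
    # p_t is the predicted probability of the true class.
    p_t = p * target + (1.0 - p) * (1.0 - target)
    # Modulating factor (1 - p_t) ** gamma down-weights easy examples.
    loss = ce * (1.0 - p_t) ** gamma
    # A negative alpha (e.g. -1) skips the class-balancing weight.
    if alpha >= 0:
        alpha_t = alpha * target + (1.0 - alpha) * (1.0 - target)
        loss *= alpha_t
    return loss


def focal_loss(logits: Sequence[float], targets: Sequence[float],
               alpha: float = 0.25, gamma: float = 2.0,
               reduction: str = "none") -> Union[List[float], float]:
    """Apply focal_loss_element elementwise, then the chosen reduction."""
    losses = [focal_loss_element(x, y, alpha, gamma)
              for x, y in zip(logits, targets)]
    if reduction == "none":
        return losses
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    raise ValueError(f"invalid reduction: {reduction!r}")
```

For example, focal_loss([0.0, 0.0], [1.0, 0.0], reduction="sum") gives 0.25 * log(2): at a logit of 0 both elements have p_t = 0.5, so each cross-entropy term log(2) is scaled by (0.5)^2 = 0.25, and then by α_t = 0.25 for the positive element and 0.75 for the negative one.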
Returns
Loss tensor with the reduction option applied: the same shape as inputs for 'none', and a scalar (0-dimensional) tensor for 'mean' or 'sum'.