torcheval.metrics.functional.binary_normalized_entropy
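torcheval.metrics.functional.binary_normalized_entropy(input: Tensor, target: Tensor, *, weight: Optional[Tensor] = None, num_tasks: int = 1, from_logits: bool = False) → Tensor

Compute the normalized binary cross entropy (NE) between the predicted input and the ground-truth binary target, that is, the (weighted) average binary cross entropy divided by the entropy of the (weighted) positive rate in target. The arguments weight, num_tasks, and from_logits are keyword-only. Its class version is torcheval.metrics.BinaryNormalizedEntropy.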
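Written out, the quantity is commonly defined as below. This is stated as an assumption, not quoted from the torcheval source, although it does reproduce the values in the examples (e.g. 1.4183 and 3.1087). Here y_i is the target, w_i the weight (1 for every sample when weight is None), p_i the predicted probability (σ(input_i) when from_logits=True), and the sums run over the samples of a single task.

```latex
\mathrm{NE} = \frac{-\dfrac{1}{\sum_i w_i} \sum_i w_i \left[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \right]}
                   {-\left[ \bar{p} \log \bar{p} + (1 - \bar{p}) \log(1 - \bar{p}) \right]},
\qquad
\bar{p} = \frac{\sum_i w_i \, y_i}{\sum_i w_i}
```

A value below 1 means the predictions beat a constant predictor that always outputs the positive rate p̄; a value above 1 means they do worse.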
Parameters:
- input (Tensor) – Predicted unnormalized scores (often referred to as logits) or binary class probabilities, of shape (num_tasks, num_samples), or (num_samples,) when num_tasks is 1.
- target (Tensor) – Ground-truth binary labels (0.0 or 1.0), with the same shape as input.
- weight (Tensor, optional) – A manual rescaling weight for each sample, with the same shape as input. Default value is None, which weights all samples equally.
- num_tasks (int) – Number of tasks for which normalized entropy is computed; one value is returned per task. Default value is 1.
- from_logits (bool) – Whether input holds logits (real values in (-inf, inf)) rather than probabilities (values in [0, 1]). Set to True for logits and False for probabilities. Default value is False.
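To make the roles of weight and from_logits concrete, here is a minimal sketch of the computation in plain PyTorch, assuming NE is the weighted mean binary cross entropy divided by the entropy of the weighted positive rate. normalized_entropy_reference is a hypothetical helper, not part of torcheval, and may differ from the library in edge cases: for example, a task whose targets are all 0 or all 1 has zero baseline entropy, and this sketch returns nan there. It reduces over the last dimension, so a (num_tasks, num_samples) input yields one value per task.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def normalized_entropy_reference(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    from_logits: bool = False,
) -> torch.Tensor:
    """Hypothetical NE reference; reduces over the last (num_samples) dimension."""
    input = input.double()
    target = target.double()
    # Without weights, every sample counts equally.
    weight = torch.ones_like(target) if weight is None else weight.double()

    # Per-sample binary cross entropy; the *_with_logits variant applies the
    # sigmoid internally, so logits are never converted by hand.
    if from_logits:
        bce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
    else:
        bce = F.binary_cross_entropy(input, target, reduction="none")

    total_weight = weight.sum(dim=-1)
    mean_bce = (weight * bce).sum(dim=-1) / total_weight

    # Baseline: entropy of a constant predictor that always outputs the
    # weighted positive rate of the targets.
    pos_rate = (weight * target).sum(dim=-1) / total_weight
    baseline = -(pos_rate * torch.log(pos_rate) + (1 - pos_rate) * torch.log(1 - pos_rate))
    return mean_bce / baseline


# Weighted probabilities, as in the second example.
print(normalized_entropy_reference(
    torch.tensor([0.2, 0.3]), torch.tensor([1.0, 0.0]), weight=torch.tensor([5.0, 1.0])
))  # approximately 3.1087
```

Called on the inputs from the examples, this sketch gives approximately 1.4183, 3.1087, 1.4183, and [1.4183, 2.1610], matching the documented outputs.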
Examples:
>>> import torch
>>> from torcheval.metrics.functional import binary_normalized_entropy
>>> input = torch.tensor([0.2, 0.3])
>>> target = torch.tensor([1.0, 0.0])
>>> weight = None
>>> binary_normalized_entropy(input, target, weight=weight, from_logits=False)
tensor(1.4183, dtype=torch.float64)
>>> input = torch.tensor([0.2, 0.3])
>>> target = torch.tensor([1.0, 0.0])
>>> weight = torch.tensor([5.0, 1.0])
>>> binary_normalized_entropy(input, target, weight=weight, from_logits=False)
tensor(3.1087, dtype=torch.float64)
>>> input = torch.tensor([-1.3863, -0.8473])
>>> target = torch.tensor([1.0, 0.0])
>>> weight = None
>>> binary_normalized_entropy(input, target, weight=weight, from_logits=True)
tensor(1.4183, dtype=torch.float64)
>>> input = torch.tensor([[0.2, 0.3], [0.5, 0.1]])
>>> target = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
>>> weight = None
>>> binary_normalized_entropy(input, target, weight=weight, num_tasks=2, from_logits=False)
tensor([1.4183, 2.1610], dtype=torch.float64)
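The logits in the third example are the log-odds of the probabilities in the first example, which is why both calls return 1.4183. A minimal sketch of that conversion, using the standard PyTorch functions torch.logit and torch.sigmoid:

```python
import torch

# Probabilities used in the first example.
probs = torch.tensor([0.2, 0.3])

# Log-odds log(p / (1 - p)); these are the logits passed with from_logits=True.
logits = torch.logit(probs)
print(logits)  # approximately tensor([-1.3863, -0.8473])

# The sigmoid maps the logits back to the original probabilities.
recovered = torch.sigmoid(logits)
print(torch.allclose(recovered, probs))  # True
```

So a model that outputs raw scores can pass them directly with from_logits=True instead of applying torch.sigmoid first.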