torchrec.metrics

torchrec.metrics.accuracy

class torchrec.metrics.accuracy.AccuracyMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.accuracy.AccuracyMetricComputation(*args: Any, threshold: float = 0.5, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for Accuracy.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

Parameters:

threshold (float) – If provided, computes accuracy metrics cutting off at the specified threshold.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.accuracy.compute_accuracy(accuracy_sum: Tensor, weighted_num_samples: Tensor) Tensor
torchrec.metrics.accuracy.compute_accuracy_sum(labels: Tensor, predictions: Tensor, weights: Tensor, threshold: float = 0.5) Tensor
torchrec.metrics.accuracy.get_accuracy_states(labels: Tensor, predictions: Tensor, weights: Optional[Tensor], threshold: float = 0.5) Dict[str, Tensor]
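
As an illustration of the threshold parameter, the following plain-Python sketch computes a weighted accuracy for a single task. It is not torchrec's implementation: the helper name weighted_accuracy is hypothetical, and it assumes that a prediction at or above the threshold counts as a positive prediction.

```python
from typing import Sequence


def weighted_accuracy(
    labels: Sequence[int],
    predictions: Sequence[float],
    weights: Sequence[float],
    threshold: float = 0.5,
) -> float:
    """Weighted share of examples whose thresholded prediction matches the label."""
    correct = 0.0
    total = 0.0
    for y, p, w in zip(labels, predictions, weights):
        # Assumption: predictions >= threshold are treated as positive.
        predicted_positive = p >= threshold
        if predicted_positive == (y == 1):
            correct += w
        total += w
    # Assumption: an empty or zero-weight batch yields 0.0.
    return correct / total if total > 0 else 0.0


labels = [1, 0, 1, 0]
predictions = [0.9, 0.6, 0.4, 0.1]
weights = [1.0, 1.0, 2.0, 1.0]

acc_default = weighted_accuracy(labels, predictions, weights)
acc_low_threshold = weighted_accuracy(labels, predictions, weights, threshold=0.3)
```

Lowering the threshold here turns the third example (weight 2.0) into a correct positive prediction, which is why the second value is higher.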

torchrec.metrics.auc

class torchrec.metrics.auc.AUCMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.auc.AUCMetricComputation(*args: Any, grouped_auc: bool = False, apply_bin: bool = False, fused_update_limit: int = 0, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for AUC, i.e. Area Under the Curve.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

Parameters:

grouped_auc (bool) – If True, computes AUC per group and returns the average AUC across all groups. The grouping_keys are provided during state updates along with predictions, labels, and weights. This feature is currently not enabled for fused_update_limit.

reset() None

Reset metric state variables to their default value.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None
Parameters:
  • predictions (torch.Tensor) – tensor of size (n_task, n_examples)

  • labels (torch.Tensor) – tensor of size (n_task, n_examples)

  • weights (torch.Tensor) – tensor of size (n_task, n_examples)

  • grouping_key (torch.Tensor) – Optional tensor of size (1, n_examples) that specifies the groups of predictions/labels per batch. If provided, the AUC metric also computes AUC per group and returns the average AUC across all groups.

torchrec.metrics.auc.compute_auc(n_tasks: int, predictions: List[Tensor], labels: List[Tensor], weights: List[Tensor], apply_bin: bool = False) Tensor

Computes AUC (Area Under the Curve) for binary classification.

Parameters:
  • n_tasks (int) – number of tasks.

  • predictions (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • labels (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • weights (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).
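
For intuition, a weighted ROC AUC for one task can be sketched in plain Python by sweeping the examples in order of decreasing prediction and integrating the curve with the trapezoid rule. This is only a sketch under stated assumptions, not the torchrec implementation: the helper name weighted_auc is hypothetical, tied predictions are processed as one step, and the degenerate case (no positives or no negatives) returns 0.5 by assumption.

```python
from typing import Sequence


def weighted_auc(
    predictions: Sequence[float],
    labels: Sequence[int],
    weights: Sequence[float],
) -> float:
    """Weighted ROC AUC for one binary task via a trapezoid sweep."""
    order = sorted(range(len(predictions)), key=lambda i: predictions[i], reverse=True)
    tp = fp = 0.0            # cumulative weighted true / false positives
    prev_tp = prev_fp = 0.0
    area = 0.0
    i = 0
    while i < len(order):
        j = i
        # Consume every example that shares the same prediction value.
        while j < len(order) and predictions[order[j]] == predictions[order[i]]:
            idx = order[j]
            tp += weights[idx] * labels[idx]
            fp += weights[idx] * (1 - labels[idx])
            j += 1
        # Trapezoid between the previous and the current curve point.
        area += (fp - prev_fp) * (tp + prev_tp) / 2.0
        prev_tp, prev_fp = tp, fp
        i = j
    if tp == 0.0 or fp == 0.0:
        return 0.5  # assumption for the degenerate single-class case
    return area / (tp * fp)


auc_value = weighted_auc(
    predictions=[0.9, 0.8, 0.3, 0.1],
    labels=[1, 0, 1, 0],
    weights=[1.0, 1.0, 1.0, 1.0],
)
```

In the example, three of the four positive/negative pairs are ranked correctly, so the result is 0.75.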

torchrec.metrics.auc.compute_auc_per_group(n_tasks: int, predictions: List[Tensor], labels: List[Tensor], weights: List[Tensor], grouping_keys: Tensor) Tensor

Computes AUC (Area Under the Curve) for binary classification for groups of predictions/labels.

Parameters:
  • n_tasks (int) – number of tasks.

  • predictions (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • labels (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • weights (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • grouping_keys (torch.Tensor) – tensor of size (n_examples,).

Returns:

tensor of size (n_tasks,), average of AUCs per group.

Return type:

torch.Tensor

torchrec.metrics.auprc

class torchrec.metrics.auprc.AUPRCMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.auprc.AUPRCMetricComputation(*args: Any, grouped_auprc: bool = False, fused_update_limit: int = 0, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for AUPRC, i.e. Area Under the Precision-Recall Curve.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

Parameters:

grouped_auprc (bool) – If True, computes AUPRC per group and returns the average AUPRC across all groups. The grouping_keys are provided during state updates along with predictions, labels, and weights. This feature is currently not enabled for fused_update_limit.

reset() None

Reset metric state variables to their default value.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None
Parameters:
  • predictions (torch.Tensor) – tensor of size (n_task, n_examples)

  • labels (torch.Tensor) – tensor of size (n_task, n_examples)

  • weights (torch.Tensor) – tensor of size (n_task, n_examples)

  • grouping_key (torch.Tensor) – Optional tensor of size (1, n_examples) that specifies the groups of predictions/labels per batch. If provided, the PR AUC metric also computes PR AUC per group and returns the average PR AUC across all groups.

torchrec.metrics.auprc.compute_auprc(n_tasks: int, predictions: Tensor, labels: Tensor, weights: Tensor) Tensor

Computes AUPRC (Area Under the Precision-Recall Curve) for binary classification.

Parameters:
  • n_tasks (int) – number of tasks.

  • predictions (torch.Tensor) – tensor of size (n_tasks, n_examples).

  • labels (torch.Tensor) – tensor of size (n_tasks, n_examples).

  • weights (torch.Tensor) – tensor of size (n_tasks, n_examples).

torchrec.metrics.auprc.compute_auprc_per_group(n_tasks: int, predictions: Tensor, labels: Tensor, weights: Tensor, grouping_keys: Tensor) Tensor

Computes AUPRC (Area Under the Precision-Recall Curve) for binary classification for groups of predictions/labels.

Parameters:
  • n_tasks (int) – number of tasks.

  • predictions (torch.Tensor) – tensor of size (n_tasks, n_examples).

  • labels (torch.Tensor) – tensor of size (n_tasks, n_examples).

  • weights (torch.Tensor) – tensor of size (n_tasks, n_examples).

  • grouping_keys (torch.Tensor) – tensor of size (n_examples,).

Returns:

tensor of size (n_tasks,), average of AUPRCs per group.

Return type:

torch.Tensor

torchrec.metrics.calibration

class torchrec.metrics.calibration.CalibrationMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.calibration.CalibrationMetricComputation(*args: Any, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for Calibration, which is the ratio between the prediction and the labels (conversions).

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.calibration.compute_calibration(calibration_num: Tensor, calibration_denom: Tensor) Tensor
torchrec.metrics.calibration.get_calibration_states(labels: Tensor, predictions: Tensor, weights: Tensor) Dict[str, Tensor]
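
The ratio described above can be illustrated with a short plain-Python sketch for one task. The helper name calibration_ratio is hypothetical, and the handling of a zero denominator (returning NaN) is an assumption of this sketch rather than documented torchrec behavior.

```python
from typing import Sequence


def calibration_ratio(
    labels: Sequence[float],
    predictions: Sequence[float],
    weights: Sequence[float],
) -> float:
    """Weighted sum of predictions divided by weighted sum of labels."""
    num = sum(w * p for p, w in zip(predictions, weights))
    denom = sum(w * y for y, w in zip(labels, weights))
    # Assumption: report NaN when there are no positive labels.
    return num / denom if denom > 0 else float("nan")


# Predictions sum to the same mass as the labels: ratio of 1.0.
well_calibrated = calibration_ratio(
    labels=[1, 0, 0, 1],
    predictions=[0.5, 0.25, 0.25, 1.0],
    weights=[1.0, 1.0, 1.0, 1.0],
)

# Predictions carry twice the mass of the labels: ratio of 2.0.
over_predicting = calibration_ratio(
    labels=[1, 0, 0, 0],
    predictions=[0.5, 0.5, 0.5, 0.5],
    weights=[1.0, 1.0, 1.0, 1.0],
)
```

A value above 1.0 indicates that the model predicts more conversions than were observed, and a value below 1.0 indicates the opposite.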

torchrec.metrics.ctr

class torchrec.metrics.ctr.CTRMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.ctr.CTRMetricComputation(*args: Any, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for CTR, i.e. Click Through Rate, which is the ratio between the predicted positive examples and the total examples.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.ctr.compute_ctr(ctr_num: Tensor, ctr_denom: Tensor) Tensor
torchrec.metrics.ctr.get_ctr_states(labels: Tensor, predictions: Tensor, weights: Tensor) Dict[str, Tensor]

torchrec.metrics.mae

class torchrec.metrics.mae.MAEMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.mae.MAEMetricComputation(*args: Any, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for MAE, i.e. Mean Absolute Error.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.mae.compute_error_sum(labels: Tensor, predictions: Tensor, weights: Tensor) Tensor
torchrec.metrics.mae.compute_mae(error_sum: Tensor, weighted_num_samples: Tensor) Tensor
torchrec.metrics.mae.get_mae_states(labels: Tensor, predictions: Tensor, weights: Tensor) Dict[str, Tensor]

torchrec.metrics.mse

class torchrec.metrics.mse.MSEMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.mse.MSEMetricComputation(*args: Any, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for MSE, i.e. Mean Squared Error.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.mse.compute_error_sum(labels: Tensor, predictions: Tensor, weights: Tensor) Tensor
torchrec.metrics.mse.compute_mse(error_sum: Tensor, weighted_num_samples: Tensor) Tensor
torchrec.metrics.mse.compute_rmse(error_sum: Tensor, weighted_num_samples: Tensor) Tensor
torchrec.metrics.mse.get_mse_states(labels: Tensor, predictions: Tensor, weights: Tensor) Dict[str, Tensor]
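
The helper signatures above suggest a two-step computation: accumulate a weighted squared-error sum, then divide by the weighted number of samples. The sketch below mirrors that split in plain Python for a single task. The function names are reused for readability only; the bodies are assumptions and not copied from torchrec.

```python
import math
from typing import Sequence


def compute_error_sum(
    labels: Sequence[float],
    predictions: Sequence[float],
    weights: Sequence[float],
) -> float:
    """Weighted sum of squared errors."""
    return sum(w * (y - p) ** 2 for y, p, w in zip(labels, predictions, weights))


def compute_mse(error_sum: float, weighted_num_samples: float) -> float:
    """Mean squared error from the accumulated state."""
    return error_sum / weighted_num_samples


def compute_rmse(error_sum: float, weighted_num_samples: float) -> float:
    """Root mean squared error from the accumulated state."""
    return math.sqrt(compute_mse(error_sum, weighted_num_samples))


labels = [1.0, 2.0, 3.0]
predictions = [1.0, 4.0, 5.0]
weights = [1.0, 1.0, 2.0]

error_sum = compute_error_sum(labels, predictions, weights)  # 0 + 4 + 2 * 4
weighted_num_samples = sum(weights)
mse = compute_mse(error_sum, weighted_num_samples)
rmse = compute_rmse(error_sum, weighted_num_samples)
```

Keeping the error sum and the weight sum as separate states is what lets partial results from several batches or ranks be added together before the final division.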

torchrec.metrics.multiclass_recall

class torchrec.metrics.multiclass_recall.MulticlassRecallMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.multiclass_recall.MulticlassRecallMetricComputation(*args: Any, **kwargs: Any)

Bases: RecMetricComputation

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.multiclass_recall.compute_multiclass_recall_at_k(tp_at_k: Tensor, total_weights: Tensor) Tensor
torchrec.metrics.multiclass_recall.compute_true_positives_at_k(predictions: Tensor, labels: Tensor, weights: Tensor, n_classes: int) Tensor

Compute and return a list of weighted true positives (true predictions) at k. When k = 0, tp is counted when the 1st predicted class matches the label. When k = 1, tp is counted when either the 1st or 2nd predicted class matches the label.

Parameters:
  • predictions (Tensor) – Tensor of label predictions with shape of (n_sample, n_class) or (n_task, n_sample, n_class).

  • labels (Tensor) – Tensor of ground truth labels with shape of (n_sample, ) or (n_task, n_sample).

  • weights (Tensor) – Tensor of weight on each sample, with shape of (n_sample, ) or (n_task, n_sample).

  • n_classes (int) – Number of classes.

Output:

true_positives_list (Tensor): Tensor of true positives with shape of (n_class, ) or (n_task, n_class).

Examples

>>> predictions = torch.tensor([[0.9, 0.1, 0, 0, 0], [0.1, 0.2, 0.25, 0.15, 0.3], [0, 1.0, 0, 0, 0], [0, 0, 0.2, 0.7, 0.1]])
>>> labels = torch.tensor([0, 3, 1, 2])
>>> weights = torch.tensor([1, 0.25, 0.5, 0.25])
>>> n_classes = 5
>>> true_positives_list = compute_true_positives_at_k(predictions, labels, weights, n_classes)
>>> true_positives_list
tensor([1.5000, 1.7500, 1.7500, 2.0000, 2.0000])
torchrec.metrics.multiclass_recall.get_multiclass_recall_states(predictions: Tensor, labels: Tensor, weights: Tensor, n_classes: int) Dict[str, Tensor]
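
To make the "true positives at k" definition concrete, the following plain-Python sketch reproduces the numbers from the example above without tensors. The helper name true_positives_at_k is hypothetical and the loop-based approach is chosen for readability; it handles only the single-task (n_sample, n_class) case.

```python
from typing import List, Sequence


def true_positives_at_k(
    predictions: Sequence[Sequence[float]],
    labels: Sequence[int],
    weights: Sequence[float],
    n_classes: int,
) -> List[float]:
    """Entry k holds the weighted count of samples whose label is in the top k + 1 classes."""
    tp = [0.0] * n_classes
    for scores, label, weight in zip(predictions, labels, weights):
        # Rank classes from highest to lowest score (stable for ties).
        ranked = sorted(range(n_classes), key=lambda c: scores[c], reverse=True)
        rank = ranked.index(label)
        # A label found at position `rank` counts as a hit for every k >= rank.
        for k in range(rank, n_classes):
            tp[k] += weight
    return tp


predictions = [
    [0.9, 0.1, 0, 0, 0],
    [0.1, 0.2, 0.25, 0.15, 0.3],
    [0, 1.0, 0, 0, 0],
    [0, 0, 0.2, 0.7, 0.1],
]
labels = [0, 3, 1, 2]
weights = [1, 0.25, 0.5, 0.25]

tp_list = true_positives_at_k(predictions, labels, weights, n_classes=5)
# tp_list == [1.5, 1.75, 1.75, 2.0, 2.0]
```

The first and third samples are hits at k = 0, the fourth sample becomes a hit at k = 1, and the second sample only at k = 3.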

torchrec.metrics.ndcg

class torchrec.metrics.ndcg.NDCGComputation(*args: Any, exponential_gain: bool = False, session_key: str = 'session_id', k: int = - 1, report_ndcg_as_decreasing_curve: bool = True, remove_single_length_sessions: bool = False, scale_by_weights_tensor: bool = False, is_negative_task_mask: Optional[List[bool]] = None, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for NDCG @ K (i.e., Normalized Discounted Cumulative Gain @ K).

Specifically, this reports (1 - NDCG) so that TensorBoard shows a decreasing “loss” rather than an increasing “gain”, which makes the curve read like normalized entropy (NE) and other pointwise measures.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None
Parameters:
  • predictions – Tensor of size (n_task, n_examples)

  • labels – Tensor of size (n_task, n_examples)

  • weights – Tensor of size (n_task, n_examples)

Returns:

None. The call only updates the metric state.

class torchrec.metrics.ndcg.NDCGMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric
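
The (1 - NDCG) quantity described for NDCGComputation can be sketched for a single session in plain Python. This is a simplified illustration, not the torchrec implementation: the helper names are hypothetical, and the sketch ignores session grouping, the k cutoff, and example weights. It assumes the standard log2 position discount, with an optional exponential gain.

```python
import math
from typing import Sequence


def dcg(relevances: Sequence[float], exponential_gain: bool = False) -> float:
    """Discounted cumulative gain of relevances listed in ranked order."""
    total = 0.0
    for position, rel in enumerate(relevances):
        gain = (2.0 ** rel - 1.0) if exponential_gain else rel
        total += gain / math.log2(position + 2)  # positions are 0-based
    return total


def one_minus_ndcg(
    predictions: Sequence[float],
    labels: Sequence[float],
    exponential_gain: bool = False,
) -> float:
    """Rank by prediction, normalize by the ideal ordering, report as a loss."""
    order = sorted(range(len(predictions)), key=lambda i: predictions[i], reverse=True)
    ranked = [labels[i] for i in order]
    ideal = sorted(labels, reverse=True)
    ideal_dcg = dcg(ideal, exponential_gain)
    # Assumption: a session without any relevant item has NDCG 0.
    ndcg = dcg(ranked, exponential_gain) / ideal_dcg if ideal_dcg > 0 else 0.0
    return 1.0 - ndcg


perfect = one_minus_ndcg(predictions=[0.9, 0.5, 0.1], labels=[1, 0, 0])
worst = one_minus_ndcg(predictions=[0.1, 0.5, 0.9], labels=[1, 0, 0])
```

A perfect ranking gives 0.0, and pushing the only relevant item to the third position gives 0.5, so lower is better, as with a loss.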

torchrec.metrics.ne

class torchrec.metrics.ne.NEMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.ne.NEMetricComputation(*args: Any, include_logloss: bool = False, allow_missing_label_with_zero_weight: bool = False, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for NE, i.e. Normalized Entropy.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

Parameters:

include_logloss (bool) – return vanilla logloss as one of metrics results, on top of NE.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.ne.compute_cross_entropy(labels: Tensor, predictions: Tensor, weights: Tensor, eta: float) Tensor
torchrec.metrics.ne.compute_logloss(ce_sum: Tensor, pos_labels: Tensor, neg_labels: Tensor, eta: float) Tensor
torchrec.metrics.ne.compute_ne(ce_sum: Tensor, weighted_num_samples: Tensor, pos_labels: Tensor, neg_labels: Tensor, eta: float, allow_missing_label_with_zero_weight: bool = False) Tensor
torchrec.metrics.ne.get_ne_states(labels: Tensor, predictions: Tensor, weights: Tensor, eta: float) Dict[str, Tensor]
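
Normalized entropy divides the model's cross entropy by the cross entropy of a baseline that always predicts the average label. The plain-Python sketch below illustrates that ratio for one task. The helper name normalized_entropy is hypothetical, the use of eta to clamp probabilities away from 0 and 1 is an assumption about its role, and natural logarithms are used here (the ratio does not depend on the base as long as numerator and denominator agree).

```python
import math
from typing import Sequence


def normalized_entropy(
    labels: Sequence[float],
    predictions: Sequence[float],
    weights: Sequence[float],
    eta: float = 1e-12,
) -> float:
    """Cross entropy of the predictions divided by that of the base-rate predictor."""
    ce_sum = 0.0
    pos = 0.0  # weighted count of positive labels
    neg = 0.0  # weighted count of negative labels
    for y, p, w in zip(labels, predictions, weights):
        p = min(max(p, eta), 1.0 - eta)  # keep log() finite
        ce_sum += -w * (y * math.log(p) + (1 - y) * math.log(1.0 - p))
        pos += w * y
        neg += w * (1 - y)
    if pos + neg == 0:
        return float("nan")  # assumption: undefined without any weight
    base_rate = min(max(pos / (pos + neg), eta), 1.0 - eta)
    baseline = -(pos * math.log(base_rate) + neg * math.log(1.0 - base_rate))
    return ce_sum / baseline


labels = [1, 0, 0, 0]
weights = [1.0, 1.0, 1.0, 1.0]

ne_baseline = normalized_entropy(labels, [0.25, 0.25, 0.25, 0.25], weights)
ne_better = normalized_entropy(labels, [0.9, 0.1, 0.1, 0.1], weights)
```

Predicting the base rate for every example yields an NE of about 1.0, and a model that separates the classes yields a value below 1.0.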

torchrec.metrics.recall

class torchrec.metrics.recall.RecallMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.recall.RecallMetricComputation(*args: Any, threshold: float = 0.5, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for Recall.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

Parameters:

threshold (float) – If provided, computes Recall metrics cutting off at the specified threshold.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.recall.compute_false_neg_sum(labels: Tensor, predictions: Tensor, weights: Tensor, threshold: float = 0.5) Tensor
torchrec.metrics.recall.compute_recall(num_true_positives: Tensor, num_false_negitives: Tensor) Tensor
torchrec.metrics.recall.compute_true_pos_sum(labels: Tensor, predictions: Tensor, weights: Tensor, threshold: float = 0.5) Tensor
torchrec.metrics.recall.get_recall_states(labels: Tensor, predictions: Tensor, weights: Optional[Tensor], threshold: float = 0.5) Dict[str, Tensor]

torchrec.metrics.precision

class torchrec.metrics.precision.PrecisionMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.precision.PrecisionMetricComputation(*args: Any, threshold: float = 0.5, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for Precision.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

Parameters:

threshold (float) – If provided, computes Precision metrics cutting off at the specified threshold.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.precision.compute_false_pos_sum(labels: Tensor, predictions: Tensor, weights: Tensor, threshold: float = 0.5) Tensor
torchrec.metrics.precision.compute_precision(num_true_positives: Tensor, num_false_positives: Tensor) Tensor
torchrec.metrics.precision.compute_true_pos_sum(labels: Tensor, predictions: Tensor, weights: Tensor, threshold: float = 0.5) Tensor
torchrec.metrics.precision.get_precision_states(labels: Tensor, predictions: Tensor, weights: Optional[Tensor], threshold: float = 0.5) Dict[str, Tensor]
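
The threshold-based precision and recall computations share the same weighted counts. The plain-Python sketch below shows both for a single task. The helper name precision_recall is hypothetical, and it assumes that a prediction at or above the threshold counts as positive and that an empty denominator yields 0.0.

```python
from typing import Sequence, Tuple


def precision_recall(
    labels: Sequence[int],
    predictions: Sequence[float],
    weights: Sequence[float],
    threshold: float = 0.5,
) -> Tuple[float, float]:
    """Return (precision, recall) from weighted TP, FP and FN sums."""
    tp = fp = fn = 0.0
    for y, p, w in zip(labels, predictions, weights):
        predicted_positive = p >= threshold  # assumption: >= counts as positive
        if predicted_positive and y == 1:
            tp += w
        elif predicted_positive:
            fp += w
        elif y == 1:
            fn += w
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return precision, recall


precision, recall = precision_recall(
    labels=[1, 0, 1, 1, 1],
    predictions=[0.9, 0.8, 0.7, 0.2, 0.1],
    weights=[1.0, 1.0, 1.0, 1.0, 1.0],
)
```

Here two of the three positive predictions are correct (precision 2/3), and two of the four positive labels are recovered (recall 1/2).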

torchrec.metrics.rauc

class torchrec.metrics.rauc.RAUCMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.rauc.RAUCMetricComputation(*args: Any, grouped_rauc: bool = False, fused_update_limit: int = 0, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for RAUC, i.e. Regression AUC.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

Parameters:

grouped_rauc (bool) – If True, computes RAUC per group and returns the average RAUC across all groups. The grouping_keys are provided during state updates along with predictions, labels, and weights. This feature is currently not enabled for fused_update_limit.

reset() None

Reset metric state variables to their default value.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None
Parameters:
  • predictions (torch.Tensor) – tensor of size (n_task, n_examples)

  • labels (torch.Tensor) – tensor of size (n_task, n_examples)

  • weights (torch.Tensor) – tensor of size (n_task, n_examples)

  • grouping_key (torch.Tensor) – Optional tensor of size (1, n_examples) that specifies the groups of predictions/labels per batch. If provided, the RAUC metric also computes RAUC per group and returns the average RAUC across all groups.

torchrec.metrics.rauc.compute_rauc(n_tasks: int, predictions: List[Tensor], labels: List[Tensor], weights: List[Tensor]) Tensor

Computes RAUC (Regression AUC) for regression tasks.

Parameters:
  • predictions (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • labels (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • weights (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

torchrec.metrics.rauc.compute_rauc_per_group(n_tasks: int, predictions: List[Tensor], labels: List[Tensor], weights: List[Tensor], grouping_keys: Tensor) Tensor

Computes RAUC (Regression AUC) for regression tasks for groups of predictions/labels.

Parameters:
  • n_tasks (int) – number of tasks.

  • predictions (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • labels (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • weights (List[torch.Tensor]) – tensor of size (n_tasks, n_examples).

  • grouping_keys (torch.Tensor) – tensor of size (n_examples,).

Returns:

tensor of size (n_tasks,), average of RAUCs per group.

Return type:

torch.Tensor

torchrec.metrics.rauc.conquer_and_count(input: List[float], left_index: int, mid_index: int, right_index: int) int
torchrec.metrics.rauc.count_reverse_pairs_divide_and_conquer(input: List[float]) float
torchrec.metrics.rauc.divide(input: List[float], low: int, high: int) int
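
The helper names above point to a merge-sort style count of reverse pairs (inversions). The sketch below shows that classic divide-and-conquer technique in plain Python and one way it could feed a regression AUC. Both function names are hypothetical, and the RAUC formula used here (one minus the share of label pairs that are out of order once examples are sorted by prediction) is an assumption of this sketch, not a statement about torchrec's exact computation. Tied predictions and weights are ignored.

```python
from typing import List, Sequence, Tuple


def count_reverse_pairs(values: Sequence[float]) -> int:
    """Count pairs (i, j) with i < j and values[i] > values[j] in O(n log n)."""

    def sort_and_count(seq: List[float]) -> Tuple[List[float], int]:
        if len(seq) <= 1:
            return seq, 0
        mid = len(seq) // 2
        left, left_count = sort_and_count(seq[:mid])      # divide
        right, right_count = sort_and_count(seq[mid:])
        merged: List[float] = []
        cross = 0
        i = j = 0
        while i < len(left) and j < len(right):           # conquer and count
            if left[i] <= right[j]:
                merged.append(left[i])
                i += 1
            else:
                # Every remaining left element is larger than right[j].
                merged.append(right[j])
                cross += len(left) - i
                j += 1
        merged.extend(left[i:])
        merged.extend(right[j:])
        return merged, left_count + right_count + cross

    return sort_and_count(list(values))[1]


def regression_auc(predictions: Sequence[float], labels: Sequence[float]) -> float:
    """Share of label pairs ordered consistently with the predictions."""
    n = len(predictions)
    if n < 2:
        return 1.0  # assumption: no pairs means nothing is out of order
    order = sorted(range(n), key=lambda i: predictions[i])
    labels_by_prediction = [labels[i] for i in order]
    total_pairs = n * (n - 1) / 2
    return 1.0 - count_reverse_pairs(labels_by_prediction) / total_pairs


rauc_value = regression_auc(
    predictions=[0.1, 0.4, 0.35, 0.8],
    labels=[1.0, 2.0, 3.0, 4.0],
)
```

In the example, exactly one of the six label pairs is out of order after sorting by prediction, so the result is 5/6.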

torchrec.metrics.throughput

class torchrec.metrics.throughput.ThroughputMetric(*, batch_size: int, world_size: int, window_seconds: int, warmup_steps: int = 100)

Bases: Module

The module to calculate throughput. Throughput is defined as the number of examples trained per second across all ranks. For example, if the batch size on each rank is 512 and there are 32 ranks, throughput is 512 * 32 / time_to_train_one_step.

Parameters:
  • batch_size (int) – batch size for the trainer

  • world_size (int) – the number of trainers

  • window_seconds (int) – throughput uses a time-based window for window_throughput. This argument specifies the window size in seconds.

  • warmup_steps (int) – the number of warmup batches. No throughput is calculated until the warmup batch count is reached.

Call Args:

Not supported.

Returns:

Not supported.

Example:

throughput = ThroughputMetric(
    batch_size=128,
    world_size=4,
    window_seconds=100,
    warmup_steps=100,
)
compute() Dict[str, Tensor]
update() None
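
The throughput formula and the idea of a time-based window can be sketched in plain Python. This is an illustration only, not how ThroughputMetric is implemented: the function step_throughput and the class WindowThroughput are hypothetical, warmup handling is omitted, and timestamps are passed in explicitly so that the example is deterministic.

```python
from collections import deque
from typing import Deque


def step_throughput(batch_size: int, world_size: int, seconds_per_step: float) -> float:
    """Examples per second across all ranks for a single training step."""
    return batch_size * world_size / seconds_per_step


class WindowThroughput:
    """Examples per second over the steps seen in the last `window_seconds`."""

    def __init__(self, batch_size: int, world_size: int, window_seconds: float) -> None:
        self.examples_per_step = batch_size * world_size
        self.window_seconds = window_seconds
        self._timestamps: Deque[float] = deque()

    def update(self, now: float) -> None:
        # Record the step time and drop steps that fell out of the window.
        self._timestamps.append(now)
        while now - self._timestamps[0] > self.window_seconds:
            self._timestamps.popleft()

    def compute(self) -> float:
        if len(self._timestamps) < 2:
            return 0.0
        elapsed = self._timestamps[-1] - self._timestamps[0]
        steps = len(self._timestamps) - 1
        return steps * self.examples_per_step / elapsed if elapsed > 0 else 0.0


single_step = step_throughput(batch_size=512, world_size=32, seconds_per_step=0.5)

window = WindowThroughput(batch_size=512, world_size=32, window_seconds=100)
for t in (0.0, 1.0, 2.0, 3.0):  # one step per second
    window.update(now=t)
windowed = window.compute()
```

With one step per second, the windowed value is 512 * 32 examples per second, while a single half-second step corresponds to twice that rate.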

torchrec.metrics.weighted_avg

class torchrec.metrics.weighted_avg.WeightedAvgMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.weighted_avg.WeightedAvgMetricComputation(*args: Any, **kwargs: Any)

Bases: RecMetricComputation

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.weighted_avg.get_mean(value_sum: Tensor, num_samples: Tensor) Tensor

torchrec.metrics.xauc

class torchrec.metrics.xauc.XAUCMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: RecMetric

class torchrec.metrics.xauc.XAUCMetricComputation(*args: Any, **kwargs: Any)

Bases: RecMetricComputation

This class implements the RecMetricComputation for XAUC.

The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail.

update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor]) None

Override this method to update the state variables of your metric class.

torchrec.metrics.xauc.compute_error_sum(labels: Tensor, predictions: Tensor, weights: Tensor) Tensor
torchrec.metrics.xauc.compute_weighted_num_pairs(weights: Tensor) Tensor
torchrec.metrics.xauc.compute_xauc(error_sum: Tensor, weighted_num_pairs: Tensor) Tensor
torchrec.metrics.xauc.get_xauc_states(labels: Tensor, predictions: Tensor, weights: Tensor) Dict[str, Tensor]

torchrec.metrics.metric_module

class torchrec.metrics.metric_module.RecMetricModule(batch_size: int, world_size: int, rec_tasks: Optional[List[RecTaskInfo]] = None, rec_metrics: Optional[RecMetricList] = None, throughput_metric: Optional[ThroughputMetric] = None, state_metrics: Optional[Dict[str, StateMetric]] = None, compute_interval_steps: int = 100, min_compute_interval: float = 0.0, max_compute_interval: float = inf, memory_usage_limit_mb: float = 512)

Bases: Module

For current recommendation models, we assume three types of metrics: (1) RecMetric, (2) Throughput, and (3) StateMetric.

RecMetric is a metric that is computed from the model outputs (labels, predictions, weights).

Throughput is a standalone type because of its unique, time-based characteristic.

StateMetric is a metric that is computed from the internal logic of a model component (e.g., Optimizer).

Parameters:
  • batch_size (int) – batch size used by this trainer.

  • world_size (int) – the number of trainers.

  • rec_tasks (Optional[List[RecTaskInfo]]) – the information of the model tasks.

  • rec_metrics (Optional[RecMetricList]) – the list of the RecMetrics.

  • throughput_metric (Optional[ThroughputMetric]) – the ThroughputMetric.

  • state_metrics (Optional[Dict[str, StateMetric]]) – the dict of StateMetrics.

  • compute_interval_steps (int) – the intervals between two compute calls in the unit of batch number

  • memory_usage_limit_mb (float) – the memory usage limit for OOM check

Call Args:

Not supported.

Returns:

Not supported.

Example

>>> config = dataclasses.replace(
>>>     DefaultMetricsConfig, state_metrics=[StateMetricEnum.OPTIMIZERS]
>>> )
>>>
>>> metricModule = generate_metric_module(
>>>     metric_class=RecMetricModule,
>>>     metrics_config=config,
>>>     batch_size=128,
>>>     world_size=64,
>>>     my_rank=0,
>>>     state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
>>>     device=torch.device("cpu"),
>>>     process_group=dist.new_group([0]),
>>> )
batch_size: int
check_memory_usage(compute_count: int) None
compute() Dict[str, Union[Tensor, float]]

compute() is called when the global metrics are required, usually right before logging the metrics results to the data sink.

compute_count: int
get_memory_usage() int

Total memory of unique RecMetric tensors in bytes

get_required_inputs() Optional[List[str]]
last_compute_time: float
local_compute() Dict[str, Union[Tensor, float]]

local_compute() is called when per-trainer metrics are required. It can be used for debugging. Currently only rec_metrics is supported.

memory_usage_limit_mb: float
memory_usage_mb_avg: float
oom_count: int
rec_metrics: RecMetricList
rec_tasks: List[RecTaskInfo]
reset() None
should_compute() bool
state_metrics: Dict[str, StateMetric]
sync() None
throughput_metric: Optional[ThroughputMetric]
unsync() None
update(model_out: Dict[str, Tensor], **kwargs: Any) None

update() is called per batch, usually right after forward(), to update the local states of metrics based on model_out.

Throughput.update() is also called, because throughput is implemented as a sliding window.

world_size: int
class torchrec.metrics.metric_module.StateMetric

Bases: ABC

The interface of state metrics for a component (e.g., optimizer, qat).

abstract get_metrics() Dict[str, Union[Tensor, float]]
torchrec.metrics.metric_module.generate_metric_module(metric_class: Type[RecMetricModule], metrics_config: MetricsConfig, batch_size: int, world_size: int, my_rank: int, state_metrics_mapping: Dict[StateMetricEnum, StateMetric], device: device, process_group: Optional[ProcessGroup] = None) RecMetricModule

torchrec.metrics.rec_metric

class torchrec.metrics.rec_metric.MetricComputationReport(name: torchrec.metrics.metrics_namespace.MetricNameBase, metric_prefix: torchrec.metrics.metrics_namespace.MetricPrefix, value: torch.Tensor, description: Union[str, NoneType] = None)

Bases: object

description: Optional[str] = None
metric_prefix: MetricPrefix
name: MetricNameBase
value: Tensor
class torchrec.metrics.rec_metric.RecMetric(world_size: int, my_rank: int, batch_size: int, tasks: List[RecTaskInfo], compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size: int = 100, fused_update_limit: int = 0, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any])

Bases: Module, ABC

The main class template to implement a recommendation metric. This class contains the recommendation tasks information (RecTaskInfo) and the actual computation object (RecMetricComputation). RecMetric processes all the information related to RecTaskInfo and models, and passes the required signals to the computation object, allowing the implementation of RecMetricComputation to focus on the mathematical meaning.

A new metric that inherits RecMetric must override the following attributes in its own __init__(): _namespace and _metrics_computations. No other methods should be overridden.

Parameters:
  • world_size (int) – the number of trainers.

  • my_rank (int) – the rank of this trainer.

  • batch_size (int) – batch size used by this trainer.

  • tasks (List[RecTaskInfo]) – the information of the model tasks.

  • compute_mode (RecComputeMode) – the computation mode. See RecComputeMode.

  • window_size (int) – the window size for the window metric.

  • fused_update_limit (int) – the maximum number of updates to be fused.

  • compute_on_all_ranks (bool) – whether to compute metrics on all ranks. This is necessary if a non-leader rank wants to consume the global metric results.

  • should_validate_update (bool) – whether to check the inputs of update() and skip the update if the inputs are invalid. Invalid inputs include the case where all examples have 0 weights for a batch.

  • process_group (Optional[ProcessGroup]) – the process group used for the communication. Will use the default process group if not specified.
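
The effect of should_validate_update can be illustrated with a small sketch. The helpers below are hypothetical (is_valid_update, maybe_update, and their list-based inputs are not part of torchrec); they assume only what the parameter description states, namely that a batch in which every example has weight 0 is treated as invalid and skipped.

```python
from typing import List, Optional


def is_valid_update(weights: Optional[List[float]]) -> bool:
    """Hypothetical check: a batch is invalid when every example weight is 0."""
    if weights is None:
        # No weights supplied: treat every example as counted.
        return True
    return any(w != 0 for w in weights)


def maybe_update(state: List[float], labels: List[float],
                 weights: Optional[List[float]]) -> bool:
    """Accumulate the weighted label sum only for valid batches.

    Returns True when the batch was applied and False when it was skipped.
    """
    if not is_valid_update(weights):
        return False
    w = weights if weights is not None else [1.0] * len(labels)
    state[0] += sum(wi * li for wi, li in zip(w, labels))
    return True


state = [0.0]
applied = maybe_update(state, [1.0, 0.0], [0.0, 0.0])    # all weights are 0: skipped
applied_2 = maybe_update(state, [1.0, 0.0], [2.0, 1.0])  # valid batch: applied
```

Skipping such batches keeps an all-zero-weight batch from touching the accumulated state at all, which is the behavior the parameter description implies.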

Example:

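from torchrec.metrics.metrics_config import DefaultTaskInfo
from torchrec.metrics.ne import NEMetric

ne = NEMetric(
    world_size=4,
    my_rank=0,
    batch_size=128,
    tasks=[DefaultTaskInfo],  # tasks expects a List[RecTaskInfo]
)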
LABELS: str = 'labels'
PREDICTIONS: str = 'predictions'
WEIGHTS: str = 'weights'
compute() Dict[str, Tensor]
get_memory_usage() Dict[Tensor, int]

Estimates the memory used by the rec metric instance’s underlying tensors and returns a map from each tensor to its size.

get_required_inputs() Set[str]
local_compute() Dict[str, Tensor]
reset() None
state_dict(destination: Optional[Dict[str, Tensor]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Tensor]

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
sync() None
unsync() None
update(*, predictions: Union[Tensor, Dict[str, Tensor]], labels: Union[Tensor, Dict[str, Tensor]], weights: Optional[Union[Tensor, Dict[str, Tensor]]], **kwargs: Dict[str, Any]) None
class torchrec.metrics.rec_metric.RecMetricComputation(my_rank: int, batch_size: int, n_tasks: int, window_size: int, compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[ProcessGroup] = None, fused_update_limit: int = 0, allow_missing_label_with_zero_weight: bool = False, *args: Any, **kwargs: Any)

Bases: Metric, ABC

The internal computation class template. A metric implementation should override update() and compute(). These two APIs focus on the actual mathematical meaning of the metric, without detailed knowledge of model output and task information.

Parameters:
  • my_rank (int) – the rank of this trainer.

  • batch_size (int) – batch size used by this trainer.

  • n_tasks (int) – the number of tasks this computation object will have to compute.

  • window_size (int) – the window size for the window metric.

  • compute_on_all_ranks (bool) – whether to compute metrics on all ranks. This is necessary if the non-leader rank wants to consume the metrics results.

  • should_validate_update (bool) – whether to check the inputs of update() and skip the update if the inputs are invalid. Invalid inputs include the case where all examples have 0 weights for a batch.

  • process_group (Optional[ProcessGroup]) – the process group used for the communication. Will use the default process group if not specified.
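
The split between update() and compute() described above can be sketched in plain Python. ToyAccuracyComputation is a hypothetical stand-in, not a torchrec class: it uses Python lists instead of tensors, ignores windowing and distributed synchronization, and assumes a weighted, thresholded accuracy purely for illustration.

```python
from typing import List


class ToyAccuracyComputation:
    """Hypothetical stand-in for a RecMetricComputation subclass (not torchrec code)."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold
        # State variables accumulated across update() calls.
        self.accuracy_sum = 0.0
        self.weighted_num_samples = 0.0

    def update(self, *, predictions: List[float], labels: List[float],
               weights: List[float]) -> None:
        # Add this batch's weighted count of correct predictions to the state.
        for p, y, w in zip(predictions, labels, weights):
            predicted_label = 1.0 if p >= self.threshold else 0.0
            self.accuracy_sum += w * float(predicted_label == y)
            self.weighted_num_samples += w

    def compute(self) -> float:
        # Turn the accumulated state into the final metric value.
        if self.weighted_num_samples == 0:
            return 0.0
        return self.accuracy_sum / self.weighted_num_samples


metric = ToyAccuracyComputation()
metric.update(
    predictions=[0.9, 0.2, 0.6],
    labels=[1.0, 0.0, 0.0],
    weights=[1.0, 1.0, 2.0],
)
accuracy = metric.compute()  # weighted correct = 2.0, total weight = 4.0
```

Neither method needs to know which model output or task the values came from; in this sketch the caller is assumed to have already selected the right predictions, labels, and weights.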

compute() List[MetricComputationReport]

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in a distributed backend.

get_window_state(state_name: str) Tensor
static get_window_state_name(state_name: str) str
local_compute() List[MetricComputationReport]
pre_compute() None

If a metric needs to do some work before compute(), it has to override pre_compute(). One possible use is to pre-process the local state before compute(), because TorchMetric wraps RecMetricComputation.compute() and performs the global aggregation before RecMetricComputation.compute() is called.

reset() None

Reset metric state variables to their default value.

abstract update(*, predictions: Optional[Tensor], labels: Tensor, weights: Optional[Tensor], **kwargs: Dict[str, Any]) None

Override this method to update the state variables of your metric class.

exception torchrec.metrics.rec_metric.RecMetricException

Bases: Exception

class torchrec.metrics.rec_metric.RecMetricList(rec_metrics: List[RecMetric])

Bases: Module

A list module to encapsulate multiple RecMetric instances and provide the same interfaces as RecMetric.

Parameters:

rec_metrics (List[RecMetric]) – the list of the input RecMetrics.

Call Args:

Not supported.

Returns:

Not supported.

Example:

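from torchrec.metrics.metrics_config import DefaultTaskInfo
from torchrec.metrics.ne import NEMetric
from torchrec.metrics.rec_metric import RecMetricList

ne = NEMetric(
    world_size=4,
    my_rank=0,
    batch_size=128,
    tasks=[DefaultTaskInfo],  # tasks expects a List[RecTaskInfo]
)
metrics = RecMetricList([ne])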
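
The idea of a list that exposes the same interface as its members can be sketched without torchrec. ToyMetric and ToyMetricList below are hypothetical classes written only for this illustration; the sketch assumes that each call is forwarded to every wrapped metric and that the per-metric result dictionaries are merged into one.

```python
from typing import Dict, List


class ToyMetric:
    """Hypothetical metric exposing update() and compute() (not a torchrec class)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.total = 0.0

    def update(self, *, predictions: List[float]) -> None:
        self.total += sum(predictions)

    def compute(self) -> Dict[str, float]:
        return {self.name: self.total}


class ToyMetricList:
    """Forwards each call to every wrapped metric and merges the results."""

    def __init__(self, metrics: List[ToyMetric]) -> None:
        self.metrics = metrics

    def update(self, *, predictions: List[float]) -> None:
        for metric in self.metrics:
            metric.update(predictions=predictions)

    def compute(self) -> Dict[str, float]:
        merged: Dict[str, float] = {}
        for metric in self.metrics:
            merged.update(metric.compute())
        return merged


toy_metrics = ToyMetricList([ToyMetric("a"), ToyMetric("b")])
toy_metrics.update(predictions=[1.0, 2.0])
result = toy_metrics.compute()
```
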
compute() Dict[str, Tensor]
get_required_inputs() Optional[List[str]]
local_compute() Dict[str, Tensor]
rec_metrics: ModuleList
required_inputs: Optional[List[str]]
reset() None
sync() None
unsync() None
update(*, predictions: Union[Tensor, Dict[str, Tensor]], labels: Union[Tensor, Dict[str, Tensor]], weights: Union[Tensor, Dict[str, Tensor]], **kwargs: Dict[str, Any]) None
class torchrec.metrics.rec_metric.WindowBuffer(max_size: int, max_buffer_count: int)

Bases: object

aggregate_state(window_state: Tensor, curr_state: Tensor, size: int) None
property buffers: Deque[Tensor]
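
The signature of aggregate_state() suggests a sliding-window accumulation, which can be sketched as follows. ToyWindowBuffer is a hypothetical, float-based illustration and should not be read as torchrec's implementation: it assumes that max_size bounds the number of examples covered by the window, that max_buffer_count bounds the number of buffered batch states, and that the oldest states are subtracted from the window state when either bound is exceeded.

```python
from collections import deque
from typing import Deque, List


class ToyWindowBuffer:
    """Hypothetical float-based window buffer (a sketch, not torchrec's WindowBuffer)."""

    def __init__(self, max_size: int, max_buffer_count: int) -> None:
        self._max_size = max_size
        self._max_buffer_count = max_buffer_count
        self.buffers: Deque[float] = deque()
        self._sizes: Deque[int] = deque()
        self._used_size = 0

    def _evict_oldest(self, window_state: List[float]) -> None:
        # Remove the oldest batch state from both the buffer and the running sum.
        window_state[0] -= self.buffers.popleft()
        self._used_size -= self._sizes.popleft()

    def aggregate_state(self, window_state: List[float], curr_state: float,
                        size: int) -> None:
        # Make room if the buffer already holds the maximum number of entries.
        if len(self.buffers) == self._max_buffer_count:
            self._evict_oldest(window_state)
        self.buffers.append(curr_state)
        self._sizes.append(size)
        window_state[0] += curr_state
        self._used_size += size
        # Evict until the window covers at most max_size examples.
        while self._used_size > self._max_size:
            self._evict_oldest(window_state)


window = [0.0]  # one-element list so the state can be updated in place
buf = ToyWindowBuffer(max_size=4, max_buffer_count=10)
for batch_sum in [1.0, 2.0, 3.0]:
    buf.aggregate_state(window, batch_sum, size=2)
```

With a window of 4 examples and batches of 2, only the two most recent batch states remain after the third call, so the window state holds their sum.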
