Shortcuts

Source code for torcheval.metrics.ranking.click_through_rate

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, Optional, TypeVar, Union

import torch

from torcheval.metrics.functional.ranking.click_through_rate import (
    _click_through_rate_compute,
    _click_through_rate_update,
)
from torcheval.metrics.metric import Metric


TClickThroughRate = TypeVar("TClickThroughRate")


class ClickThroughRate(Metric[torch.Tensor]):
    """
    Compute the click through rate given click events.
    Its functional version is ``torcheval.metrics.functional.click_through_rate``.

    Args:
        num_tasks (int): Number of tasks that need click through rate calculation.
            Default value is 1.
        device (torch.device, optional): Device passed to the ``Metric`` base class;
            the metric states are allocated on ``self.device``. Defaults to ``None``.

    Raises:
        ValueError: If ``num_tasks`` is less than 1.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.ranking import ClickThroughRate
        >>> metric = ClickThroughRate()
        >>> input = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1])
        >>> metric.update(input)
        >>> metric.compute()
        tensor([0.5])
        >>> metric = ClickThroughRate()
        >>> weights = torch.tensor([1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0])
        >>> metric.update(input, weights)
        >>> metric.compute()
        tensor([0.58333])
        >>> metric = ClickThroughRate(num_tasks=2)
        >>> input = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]])
        >>> weights = torch.tensor([[1.0, 2.0, 1.0, 2.0],[1.0, 2.0, 1.0, 1.0]])
        >>> metric.update(input, weights)
        >>> metric.compute()
        tensor([0.6667, 0.4])
    """

    def __init__(
        self: TClickThroughRate,
        *,
        num_tasks: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        if num_tasks < 1:
            # f-string so the rejected value actually appears in the message
            # (previously the literal text "{num_tasks}" was printed).
            raise ValueError(
                f"`num_tasks` value should be greater than and equal to 1, but received {num_tasks}. "
            )
        self.num_tasks = num_tasks
        # One running sum per task, kept in float64 on the metric's device.
        self._add_state(
            "click_total",
            torch.zeros(self.num_tasks, dtype=torch.float64, device=self.device),
        )
        self._add_state(
            "weight_total",
            torch.zeros(self.num_tasks, dtype=torch.float64, device=self.device),
        )

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TClickThroughRate,
        input: torch.Tensor,
        weights: Union[torch.Tensor, float, int] = 1.0,
    ) -> TClickThroughRate:
        """
        Update the metric state with new inputs.

        Args:
            input (Tensor): Series of values representing user click (1) or skip (0)
                of shape (num_events) or (num_tasks, num_events).
            weights (Tensor, float, int): Weights for each event, single weight or tensor
                with the same shape as input.

        Returns:
            The metric instance itself, so calls can be chained.
        """
        click_total, weight_total = _click_through_rate_update(
            input, weights, num_tasks=self.num_tasks
        )
        # Out-of-place addition: the state tensors are rebound rather than mutated.
        self.click_total = self.click_total + click_total
        self.weight_total = self.weight_total + weight_total
        return self

    @torch.inference_mode()
    def compute(self: TClickThroughRate) -> torch.Tensor:
        """
        Return the stacked click through rate scores, one per task.

        The result is produced by ``_click_through_rate_compute`` from the accumulated
        ``click_total`` and ``weight_total`` states. If no ``update()`` calls are made
        before ``compute()`` is called, both totals are still all zeros; the original
        documentation states the result is then tensor(0.0) — NOTE(review): confirm the
        exact shape/value against ``_click_through_rate_compute``.
        """
        return _click_through_rate_compute(self.click_total, self.weight_total)

    @torch.inference_mode()
    def merge_state(
        self: TClickThroughRate, metrics: Iterable[TClickThroughRate]
    ) -> TClickThroughRate:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.

        Returns:
            The metric instance itself, with the other instances' totals added in.
        """
        for metric in metrics:
            # Move the other metric's states onto this metric's device before adding.
            self.click_total = self.click_total + metric.click_total.to(self.device)
            self.weight_total = self.weight_total + metric.weight_total.to(self.device)
        return self

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources