Shortcuts

Source code for torcheval.metrics.window.click_through_rate

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, Optional, Tuple, TypeVar, Union

import torch

from torcheval.metrics.functional.ranking.click_through_rate import (
    _click_through_rate_compute,
    _click_through_rate_update,
)
from torcheval.metrics.metric import Metric


# Type of ``self`` for WindowedClickThroughRate's fluent methods (``update`` and
# ``merge_state`` return ``self``). Bounding it lets type checkers resolve the
# class's attributes on ``self``.
TWindowedClickThroughRate = TypeVar(
    "TWindowedClickThroughRate", bound="WindowedClickThroughRate"
)


class WindowedClickThroughRate(
    Metric[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]
):
    """
    The windowed version of ClickThroughRate that provides both windowed and lifetime values.
    Windowed value is calculated from the input and target of the last window_size number of `update()` calls.
    Lifetime value is calculated from all past input and target of `update()` calls.

    Compute the click through rate given click events.
    Its functional version is ``torcheval.metrics.functional.click_through_rate``.

    Args:
        num_tasks (int): Number of tasks that need click through rate calculation. Default value
                    is 1.
        max_num_updates (int): The max window size that can accommodate the number of updates.
        enable_lifetime (bool): A boolean indicator whether to calculate lifetime values.
        device (torch.device, optional): Device on which the metric states are allocated.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import WindowedClickThroughRate
        >>> metric = WindowedClickThroughRate(max_num_updates=2, enable_lifetime=False)
        >>> metric.update(torch.tensor([0, 1, 0, 1, 1, 0, 0, 1]))
        >>> metric.update(torch.tensor([0, 1, 0, 1, 1, 1, 1, 1]))
        >>> metric.update(torch.tensor([0, 1, 0, 1, 0, 0, 0, 1]))
        >>> metric.compute()  # (6 + 3) clicks / 16 events over the last 2 updates
        tensor([0.5625])

    With the default ``enable_lifetime=True``, ``compute()`` instead returns a
    ``(lifetime, windowed)`` tuple; for the updates above the lifetime rate is
    13 / 24 (about 0.5417) and the windowed rate is still 0.5625.
    """

    def __init__(
        self: TWindowedClickThroughRate,
        *,
        num_tasks: int = 1,
        max_num_updates: int = 100,
        enable_lifetime: bool = True,
        device: Optional[torch.device] = None,
    ) -> None:
        """
        Validate arguments and allocate the metric states.

        Raises:
            ValueError: if ``num_tasks`` or ``max_num_updates`` is less than 1.
        """
        super().__init__(device=device)
        if num_tasks < 1:
            raise ValueError(
                f"`num_tasks` value should be greater than and equal to 1, but received {num_tasks}. "
            )
        if max_num_updates < 1:
            raise ValueError(
                f"`max_num_updates` value should be greater than and equal to 1, but received {max_num_updates}. "
            )
        self.num_tasks = num_tasks
        # Registered as a state rather than a plain attribute because
        # merge_state() grows the window; presumably reset() restores the
        # constructor value -- confirm against Metric.
        self._add_state("max_num_updates", max_num_updates)
        # Ring-buffer write cursor into the last dim of the windowed tensors.
        # NOTE(review): plain attribute, not a registered state, so it may not
        # be restored by reset()/load_state_dict() -- confirm against Metric.
        self.next_inserted = 0
        self.enable_lifetime = enable_lifetime
        self._add_state("total_updates", 0)
        if self.enable_lifetime:
            self._add_state(
                "click_total",
                torch.zeros(self.num_tasks, dtype=torch.float64, device=self.device),
            )
            self._add_state(
                "weight_total",
                torch.zeros(self.num_tasks, dtype=torch.float64, device=self.device),
            )
        # One column per update in the window: shape (num_tasks, max_num_updates).
        self._add_state(
            "windowed_click_total",
            torch.zeros(
                self.num_tasks,
                self.max_num_updates,
                dtype=torch.float64,
                device=self.device,
            ),
        )
        self._add_state(
            "windowed_weight_total",
            torch.zeros(
                self.num_tasks,
                self.max_num_updates,
                dtype=torch.float64,
                device=self.device,
            ),
        )

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TWindowedClickThroughRate,
        input: torch.Tensor,
        weights: Union[torch.Tensor, float, int] = 1.0,
    ) -> TWindowedClickThroughRate:
        """
        Update the metric state with new inputs.

        Args:
            input (Tensor): Series of values representing user click (1) or skip (0)
                            of shape (num_events) or (num_objectives, num_events).
            weights (Tensor, float, int): Weights for each event, single weight or tensor with the same shape as input.

        Returns:
            ``self``, so calls can be chained.
        """
        click_total, weight_total = _click_through_rate_update(
            input, weights, num_tasks=self.num_tasks
        )
        if self.enable_lifetime:
            self.click_total += click_total
            self.weight_total += weight_total
        # Store this update's totals in the current ring-buffer column, then
        # advance the cursor, wrapping around at max_num_updates.
        self.windowed_click_total[:, self.next_inserted] = click_total
        self.windowed_weight_total[:, self.next_inserted] = weight_total
        self.next_inserted += 1
        self.next_inserted %= self.max_num_updates
        self.total_updates += 1
        return self

    @torch.inference_mode()
    def compute(
        self: TWindowedClickThroughRate,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Return the click through rate scores, one entry per task.

        Returns:
            If ``enable_lifetime`` is True, a ``(lifetime, windowed)`` tuple of
            click through rates; otherwise only the windowed rates.
            If no ``update()`` calls have been made, returns ``torch.empty(0)``
            (a tuple of two such empty tensors when ``enable_lifetime`` is True).
        """
        if self.total_updates == 0:
            if self.enable_lifetime:
                return torch.empty(0), torch.empty(0)
            else:
                return torch.empty(0)

        if self.total_updates >= self.max_num_updates:
            # The window has been filled at least once: every column holds data
            # (or is still zero), so summing all of them is correct.
            click_total = self.windowed_click_total.sum(dim=-1)
            weight_total = self.windowed_weight_total.sum(dim=-1)
        else:
            # The window hasn't been filled yet: only columns before the cursor
            # have been written.
            click_total = self.windowed_click_total[:, : self.next_inserted].sum(dim=-1)
            weight_total = self.windowed_weight_total[:, : self.next_inserted].sum(
                dim=-1
            )
        windowed_click_through_rate = _click_through_rate_compute(
            click_total, weight_total
        )

        if self.enable_lifetime:
            lifetime_click_through_rate = _click_through_rate_compute(
                self.click_total, self.weight_total
            )
            return (
                lifetime_click_through_rate,
                windowed_click_through_rate,
            )
        else:
            return windowed_click_through_rate

    @torch.inference_mode()
    def merge_state(
        self: TWindowedClickThroughRate, metrics: Iterable[TWindowedClickThroughRate]
    ) -> TWindowedClickThroughRate:
        """
        Merge the metric state with its counterparts from other metric instances.
        First create tensors of size equal to the sum of all metrics' window sizes.
        Then, put all tensors to the front and leave the remaining indices zeros.
        Afterwards ``max_num_updates`` equals that summed window size, so later
        ``update()`` calls cycle over the whole merged window.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.

        Returns:
            ``self``, with the merged state.
        """
        # `metrics` is traversed twice below; materialize it so a one-shot
        # iterator (e.g. a generator) is not exhausted by the first pass.
        metrics = list(metrics)
        merge_max_num_updates = self.max_num_updates + sum(
            metric.max_num_updates for metric in metrics
        )

        cur_click_total = self.windowed_click_total
        cur_weight_total = self.windowed_weight_total
        self.windowed_click_total = torch.zeros(
            self.num_tasks,
            merge_max_num_updates,
            dtype=torch.float64,
            device=self.device,
        )
        self.windowed_weight_total = torch.zeros(
            self.num_tasks,
            merge_max_num_updates,
            dtype=torch.float64,
            device=self.device,
        )

        # Columns at or beyond min(total_updates, max_num_updates) have never
        # been written (they are zero), so copying this prefix keeps all data.
        cur_size = min(self.total_updates, self.max_num_updates)
        self.windowed_click_total[:, :cur_size] = cur_click_total[:, :cur_size]
        self.windowed_weight_total[:, :cur_size] = cur_weight_total[:, :cur_size]
        idx = cur_size

        for metric in metrics:
            if self.enable_lifetime:
                self.click_total += metric.click_total.to(self.device)
                self.weight_total += metric.weight_total.to(self.device)
            cur_size = min(metric.total_updates, metric.max_num_updates)
            # Move to this metric's device first, as for the lifetime totals.
            self.windowed_click_total[:, idx : idx + cur_size] = (
                metric.windowed_click_total[:, :cur_size].to(self.device)
            )
            self.windowed_weight_total[:, idx : idx + cur_size] = (
                metric.windowed_weight_total[:, :cur_size].to(self.device)
            )
            idx += cur_size
            self.total_updates += metric.total_updates

        # Keep the window size in sync with the widened tensors. Otherwise later
        # update() calls would cycle over only the old width, leave merged
        # columns stale, and a further merge_state() would truncate them.
        self.max_num_updates = merge_max_num_updates
        self.next_inserted = idx % self.max_num_updates
        return self

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources