torcheval.metrics.RetrievalPrecision

class torcheval.metrics.RetrievalPrecision(empty_target_action: Union[Literal['neg'], Literal['pos'], Literal['skip'], Literal['err']] = 'neg', k: Optional[int] = None, limit_k_to_size: bool = False, num_queries: int = 1, avg: Optional[Union[Literal['macro'], Literal['none']]] = None, device: Optional[device] = None)

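Compute the retrieval precision: the fraction of the top-k elements of input (ranked by decreasing score) whose target is positive. Its functional version is torcheval.metrics.functional.retrieval_precision(). Throughout, input and target refer to the arguments of the update function.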

Parameters:
  • k (int, optional) – The number of elements considered as retrieved. Only the top k elements of input (sorted in decreasing order) are considered. If k is None, all input elements are considered.
  • limit_k_to_size (bool, default value: False) – When set to True, limits k to at most the length of input, i.e. replaces k with min(k, len(input)). This parameter can only be set to True if k is not None.
  • empty_target_action (str, one of "neg", "pos", "skip", "err"; default value: "neg") – Chooses the behaviour of the update function when target does not contain at least one positive element:
      - "neg": retrieval precision is 0.0,
      - "pos": retrieval precision is 1.0,
      - "skip": retrieval precision is NaN,
      - "err": a ValueError is raised.
  • num_queries (int, default value: 1) – If greater than 1, input and target can contain entries belonging to different queries. An indexes tensor must then be passed to update, associating each input and target entry with an integer between 0 and num_queries - 1. Each query's output is computed independently; with avg set to "none" or None, .compute() returns a tensor of shape (num_queries,).
  • avg (str or None, one of "macro", "none", None; default value: None) – Chooses the averaging method over all queries:
      - "none" or None: .compute() returns a tensor of shape (num_queries,) whose i-th value is the retrieval precision of the i-th query,
      - "macro": .compute() returns the average retrieval precision over all queries.
  • device (torch.device, optional) – The torch device to use.
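The parameter semantics above can be sketched for a single query in plain Python. precision_at_k below is a hypothetical helper written only to illustrate the documented behaviour of k, limit_k_to_size and empty_target_action; it is not torcheval's implementation and does not use torch. Two details are our assumptions: ties in input are broken by original position (torch's top-k may order ties differently), and when k exceeds len(input) without limit_k_to_size, the hit count is still divided by k.

```python
import math
from typing import Optional, Sequence


def precision_at_k(
    input: Sequence[float],
    target: Sequence[int],
    k: Optional[int] = None,
    limit_k_to_size: bool = False,
    empty_target_action: str = "neg",
) -> float:
    """Hypothetical plain-Python sketch of single-query retrieval precision."""
    # A query with no positive target is resolved by empty_target_action.
    if not any(target):
        if empty_target_action == "neg":
            return 0.0
        if empty_target_action == "pos":
            return 1.0
        if empty_target_action == "skip":
            return math.nan
        raise ValueError(
            f"target has no positive element (empty_target_action={empty_target_action!r})"
        )

    n = len(input)
    if k is None:
        k = n  # every element counts as retrieved
    elif limit_k_to_size:
        k = min(k, n)

    # Rank positions by decreasing score; Python's sort is stable, so ties keep input order.
    ranked = sorted(range(n), key=lambda i: input[i], reverse=True)
    hits = sum(1 for i in ranked[:k] if target[i])
    # Assumption: without limit_k_to_size, hits are divided by k even when k > n.
    return hits / k
```

With the input and target used in the Examples section, precision_at_k(input, target, k=2) gives 0.5, the same value metric.compute() shows there.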

Examples

>>> import torch
>>> from torcheval.metrics import RetrievalPrecision
>>> input = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = torch.tensor([0, 0, 1, 1, 1, 0, 1])
>>> metric = RetrievalPrecision(k=2)
>>> metric.update(input, target)
>>> metric.compute()
tensor(0.5000)
>>> metric = RetrievalPrecision(k=2, num_queries=2)
>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
>>> metric.update(input, target, indexes)
>>> metric.compute()
tensor([0.5000, 0.5000])
>>> target2 = torch.tensor([1, 0, 1, 0, 1, 1, 0])
>>> input2 = torch.tensor([0.4, 0.1, 0.6, 0.8, 0.7, 0.9, 0.3])
>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
>>> metric.update(input2, target2, indexes)
>>> # first query: input = [0.2, 0.3, 0.5, 0.4, 0.1, 0.6], target = [0, 0, 1, 1, 0, 1]
>>> # second query: input = [0.1, 0.3, 0.5, 0.2, 0.8, 0.7, 0.9, 0.3], target = [1, 1, 0, 1, 0, 1, 1, 0]
>>> metric.compute()
tensor([1.0000, 0.5000])
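The per-query bookkeeping in this example can be mimicked without torch. MultiQueryPrecisionSketch below is a hypothetical class, not part of torcheval: it buffers (score, label) entries per query id across update calls, ranks each query independently at compute time, and optionally macro-averages. For brevity it omits empty_target_action (a query with no positives simply yields 0.0, like "neg") and limit_k_to_size.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Union


def topk_precision(scores: Sequence[float], labels: Sequence[int], k: int) -> float:
    """Fraction of the k highest-scoring entries whose label is positive."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sum(1 for i in ranked[:k] if labels[i]) / k


class MultiQueryPrecisionSketch:
    """Hypothetical stand-in that buffers entries per query id across updates."""

    def __init__(self, k: int, num_queries: int, avg: Optional[str] = None) -> None:
        self.k = k
        self.num_queries = num_queries
        self.avg = avg
        self.scores: Dict[int, List[float]] = defaultdict(list)
        self.labels: Dict[int, List[int]] = defaultdict(list)

    def update(
        self, input: Sequence[float], target: Sequence[int], indexes: Sequence[int]
    ) -> None:
        # Route every (score, label) pair to the buffer of the query named by its index.
        for score, label, q in zip(input, target, indexes):
            if not 0 <= q < self.num_queries:
                raise ValueError(f"index {q} is outside [0, {self.num_queries - 1}]")
            self.scores[q].append(score)
            self.labels[q].append(label)

    def compute(self) -> Union[float, List[float]]:
        # Each query is ranked independently over everything buffered so far.
        per_query = [
            topk_precision(self.scores[q], self.labels[q], self.k)
            for q in range(self.num_queries)
        ]
        if self.avg == "macro":
            return sum(per_query) / len(per_query)
        return per_query
```

Feeding the sketch the two batches from the example gives [0.5, 0.5] after the first update and [1.0, 0.5] after the second; with avg="macro", the second result becomes 0.75.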

Raises:
  • ValueError – if empty_target_action is not one of "neg", "pos", "skip", "err".
  • ValueError – if limit_k_to_size is True and k is None.
  • ValueError – if k is not a positive integer.
  • ValueError – if empty_target_action == "err" and self.update is called with a target whose entries are all equal to 0.
  • ValueError – if the input and target arguments of self.update have different dimensions, or their dimension is not 1.
  • ValueError – if num_queries > 1 and the indexes argument of .update() is None.
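The conditions above fall into two groups: checks on constructor arguments and checks on update arguments. The helpers check_init_args and check_update_args below are hypothetical and only mirror the listed conditions; torcheval's own checks and error messages may differ. To keep the sketch free of torch, update-time tensors are represented by their shapes as plain tuples, and the data-dependent "err" empty-target check is left out.

```python
from typing import Optional, Tuple


def check_init_args(empty_target_action: str, k: Optional[int], limit_k_to_size: bool) -> None:
    """Hypothetical mirror of the constructor-time ValueError conditions."""
    if empty_target_action not in ("neg", "pos", "skip", "err"):
        raise ValueError(
            f"empty_target_action must be 'neg', 'pos', 'skip' or 'err', got {empty_target_action!r}"
        )
    if limit_k_to_size and k is None:
        raise ValueError("limit_k_to_size=True requires k to be set")
    if k is not None and (not isinstance(k, int) or k <= 0):
        raise ValueError(f"k must be a positive integer, got {k!r}")


def check_update_args(
    input_shape: Tuple[int, ...],
    target_shape: Tuple[int, ...],
    has_indexes: bool,
    num_queries: int,
) -> None:
    """Hypothetical mirror of the update-time checks; shapes stand in for tensors."""
    # Both arguments must be one-dimensional.
    if len(input_shape) != 1 or len(target_shape) != 1:
        raise ValueError(
            f"input and target must be 1-D, got shapes {input_shape} and {target_shape}"
        )
    # With several queries, every entry needs a query id.
    if num_queries > 1 and not has_indexes:
        raise ValueError("indexes must be passed to update() when num_queries > 1")
```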
__init__(empty_target_action: Union[Literal['neg'], Literal['pos'], Literal['skip'], Literal['err']] = 'neg', k: Optional[int] = None, limit_k_to_size: bool = False, num_queries: int = 1, avg: Optional[Union[Literal['macro'], Literal['none']]] = None, device: Optional[device] = None) → None

Initialize a metric object and its internal states.

Use self._add_state() to initialize the state variables of your metric class. Each state variable should be a torch.Tensor, a list of torch.Tensor, or a dictionary with torch.Tensor values.

Methods

__init__([empty_target_action, k, ...]) – Initialize a metric object and its internal states.
compute() – Implement this method to compute and return the final metric value from state variables.
load_state_dict(state_dict[, strict]) – Loads metric state variables from state_dict.
merge_state(metrics) – Merge the metric state with its counterparts from other metric instances.
reset() – Reset the metric state variables to their default value.
state_dict() – Save metric state variables in state_dict.
to(device, *args, **kwargs) – Move tensors in metric state variables to device.
update(input, target[, indexes]) – Update the metric state with ground truth labels and predictions.
update_single_query(i, input, target)
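merge_state is what lets metric instances that saw different shards of data (for example on different workers) be combined before compute(). PerQueryBuffer below is a hypothetical toy, not torcheval's Metric base class: it assumes the state is kept as raw (score, label) pairs per query, so merging is a per-query concatenation and compute() re-ranks everything. Under that assumption, merging two instances fed the two batches from the Examples gives the same per-query values as feeding both batches to one instance; only the order of tied scores depends on merge order.

```python
from typing import Dict, Iterable, List, Sequence, Tuple


class PerQueryBuffer:
    """Hypothetical toy state: per-query lists of (score, label) pairs."""

    def __init__(self) -> None:
        self.entries: Dict[int, List[Tuple[float, int]]] = {}

    def update(
        self, input: Sequence[float], target: Sequence[int], indexes: Sequence[int]
    ) -> "PerQueryBuffer":
        for score, label, q in zip(input, target, indexes):
            self.entries.setdefault(q, []).append((score, label))
        return self

    def merge_state(self, others: Iterable["PerQueryBuffer"]) -> "PerQueryBuffer":
        # Fold the other instances' buffered pairs into this one, query by query.
        for other in others:
            for q, pairs in other.entries.items():
                self.entries.setdefault(q, []).extend(pairs)
        return self

    def reset(self) -> "PerQueryBuffer":
        self.entries = {}
        return self

    def compute(self, k: int) -> Dict[int, float]:
        # Top-k precision per query over all buffered pairs.
        out: Dict[int, float] = {}
        for q, pairs in sorted(self.entries.items()):
            top = sorted(pairs, key=lambda p: p[0], reverse=True)[:k]
            out[q] = sum(label for _, label in top) / k
        return out
```

Keeping raw pairs makes merging exact, at the cost of memory that grows with every update; when k is fixed, a variant could keep only the k highest-scoring pairs per query, since the top k of a union is contained in the union of the per-shard top k.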

Attributes

device – The last input device of Metric.to().
