torcheval.metrics.RetrievalRecall
class torcheval.metrics.RetrievalRecall(*, empty_target_action: Union[Literal['neg'], Literal['pos'], Literal['skip'], Literal['err']] = 'neg', k: Optional[int] = None, limit_k_to_size: bool = False, num_queries: int = 1, avg: Optional[Union[Literal['macro'], Literal['none']]] = None, device: Optional[device] = None)

Compute the retrieval recall. Its functional version is torcheval.metrics.functional.retrieval_recall(). (Here, input and target refer to the arguments of the update function.)

Parameters:
- k (int, optional) – the number of elements considered as retrieved. Only the top k elements of input (sorted in decreasing order) are considered. If k is None, all the input elements are considered.
- limit_k_to_size (bool, default False) – When set to True, limits k to be at most the length of input, i.e. replaces k by k=min(k, len(input)). This parameter can only be set to True if k is not None.
- empty_target_action (str, default "neg") – one of "neg", "pos", "skip", "err". Chooses the behaviour of the update function when target does not contain at least one positive element:
  - when "neg": retrieval recall is equal to 0.0,
  - when "pos": retrieval recall is equal to 1.0,
  - when "skip": retrieval recall is equal to NaN,
  - when "err": a ValueError is raised.
- num_queries (int, default 1) – If >1, input and target can contain entries related to different queries. An indexes tensor must then be passed to update, associating each input and target entry with an integer between 0 and num_queries-1. Outputs for each query are computed independently, and .compute() returns a tensor of shape (num_queries,) unless avg is "macro".
- avg (str, optional, default None) – one of "macro", "none", or None. Chooses the averaging method over all queries:
  - when "none" or None: .compute() returns a tensor of shape (num_queries,) whose i-th value is the retrieval recall of the i-th query,
  - when "macro": .compute() returns the average retrieval recall over all queries.
- device (torch.device, optional) – the torch device to be used.
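To make the interaction of k, limit_k_to_size and empty_target_action concrete, here is a minimal plain-Python sketch of single-query retrieval recall as the parameters describe it: the fraction of all relevant elements (positive target entries) that appear among the top-k scored elements. The helper recall_at_k is hypothetical and not part of torcheval; it works on Python lists instead of tensors, and its tie-breaking (a stable sort, so the earlier position wins) is an assumption that may differ from how the library orders equal scores.

```python
import math
from typing import Optional, Sequence


def recall_at_k(
    scores: Sequence[float],
    target: Sequence[int],
    k: Optional[int] = None,
    limit_k_to_size: bool = False,
    empty_target_action: str = "neg",
) -> float:
    """Hypothetical sketch: relevant items in the top k / all relevant items."""
    if empty_target_action not in ("neg", "pos", "skip", "err"):
        raise ValueError(f"unknown empty_target_action: {empty_target_action!r}")
    if limit_k_to_size and k is None:
        raise ValueError("limit_k_to_size=True requires k to be set")
    if k is not None and k <= 0:
        raise ValueError("k must be a positive integer")
    if len(scores) != len(target):
        raise ValueError("scores and target must have the same length")

    num_relevant = sum(1 for t in target if t > 0)
    if num_relevant == 0:
        # No positive element: the result is chosen by empty_target_action.
        if empty_target_action == "neg":
            return 0.0
        if empty_target_action == "pos":
            return 1.0
        if empty_target_action == "skip":
            return math.nan
        raise ValueError("target contains no positive element")

    if k is None:
        k = len(scores)  # consider every element
    elif limit_k_to_size:
        k = min(k, len(scores))

    # Positions ordered by decreasing score; sorted() stays stable with reverse=True.
    ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # In this sketch, slicing past the end simply keeps every position.
    top_k = ranking[:k]
    retrieved_relevant = sum(1 for i in top_k if target[i] > 0)
    return retrieved_relevant / num_relevant


scores = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
labels = [0, 0, 1, 1, 1, 0, 1]
print(recall_at_k(scores, labels, k=2))  # 0.25
```

With k=2 the two highest scores (both 0.5) sit at positions 2 and 5; only position 2 is relevant, and 4 entries are relevant overall, hence 1/4.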
Examples

>>> import torch
>>> from torcheval.metrics import RetrievalRecall
>>> input = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = torch.tensor([0, 0, 1, 1, 1, 0, 1])
>>> metric = RetrievalRecall(k=2)
>>> metric.update(input, target)
>>> metric.compute()
tensor(0.25)
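The example above uses a single query. For num_queries > 1, each input/target entry is tagged with a query id through indexes, recall is computed per query, and avg="macro" averages the per-query values. The sketch below mirrors that description in plain Python; multi_query_recall and _recall_single are hypothetical helpers (not torcheval functions), and treating a query that received no entries as an empty target is an assumption of the sketch.

```python
import math
from typing import List, Optional, Sequence, Union


def _recall_single(
    scores: Sequence[float], target: Sequence[int], k: Optional[int], empty_target_action: str
) -> float:
    """Recall for one query: relevant items in the top k / all relevant items."""
    num_relevant = sum(1 for t in target if t > 0)
    if num_relevant == 0:
        if empty_target_action == "err":
            raise ValueError("a query has no positive target entry")
        return {"neg": 0.0, "pos": 1.0, "skip": math.nan}[empty_target_action]
    # Stable sort by decreasing score: ties keep their original order (an assumption).
    ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    top = ranking if k is None else ranking[:k]
    return sum(1 for i in top if target[i] > 0) / num_relevant


def multi_query_recall(
    scores: Sequence[float],
    target: Sequence[int],
    indexes: Sequence[int],
    num_queries: int,
    k: Optional[int] = None,
    avg: Optional[str] = None,
    empty_target_action: str = "neg",
) -> Union[float, List[float]]:
    """Hypothetical sketch of per-query recall with optional macro averaging."""
    if not len(scores) == len(target) == len(indexes):
        raise ValueError("scores, target and indexes must have the same length")
    if avg not in (None, "none", "macro"):
        raise ValueError(f"unsupported avg: {avg!r}")

    # Route each (score, label) pair to the query named by its index.
    grouped_scores: List[List[float]] = [[] for _ in range(num_queries)]
    grouped_target: List[List[int]] = [[] for _ in range(num_queries)]
    for s, t, q in zip(scores, target, indexes):
        if not 0 <= q < num_queries:
            raise ValueError(f"index {q} is outside [0, {num_queries - 1}]")
        grouped_scores[q].append(s)
        grouped_target[q].append(t)

    # Every query is evaluated independently of the others.
    per_query = [
        _recall_single(qs, qt, k, empty_target_action)
        for qs, qt in zip(grouped_scores, grouped_target)
    ]
    if avg == "macro":
        # Plain arithmetic mean; a NaN produced by "skip" would propagate here.
        return sum(per_query) / num_queries
    return per_query  # avg is None or "none": one value per query


scores = [0.9, 0.4, 0.1, 0.7, 0.8, 0.2]
target = [1, 1, 0, 1, 0, 0]
indexes = [0, 1, 0, 1, 0, 1]
print(multi_query_recall(scores, target, indexes, num_queries=2, k=1))  # [1.0, 0.5]
print(multi_query_recall(scores, target, indexes, num_queries=2, k=1, avg="macro"))  # 0.75
```

With the library itself, the same grouping is requested by constructing the metric with num_queries (and optionally avg="macro") and passing an indexes tensor to update(input, target, indexes), as listed under Methods; the values the library returns for this particular data are not asserted here.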
Raises:
- ValueError – if empty_target_action is not one of "neg", "pos", "skip", "err".
- ValueError – if limit_k_to_size is True and k is None.
- ValueError – if k is not a positive integer.
- ValueError – if empty_target_action == "err" and self.update is called with a target whose entries are all equal to 0.
- ValueError – if the input or target arguments of self.update are tensors with different dimensions, or with a dimension other than 1.
- ValueError – if num_queries > 1 and the indexes argument of .update() is None.
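The conditions listed under Raises fall into constructor-time checks and update-time checks. A hedged plain-Python sketch of those checks follows; check_init_args, check_update_args and check_target are hypothetical names, and neither the order in which the library runs its checks nor its error messages are reproduced. The shape arguments accept any sequence of ints, such as a tensor's .shape (torch.Size is a tuple subclass).

```python
from typing import Optional, Sequence

# Accepted values, taken from the Raises list above.
VALID_EMPTY_TARGET_ACTIONS = ("neg", "pos", "skip", "err")


def check_init_args(empty_target_action: str, k: Optional[int], limit_k_to_size: bool) -> None:
    """Constructor-time checks (sketch)."""
    if empty_target_action not in VALID_EMPTY_TARGET_ACTIONS:
        raise ValueError(
            f"empty_target_action must be one of {VALID_EMPTY_TARGET_ACTIONS}, "
            f"got {empty_target_action!r}"
        )
    if limit_k_to_size and k is None:
        raise ValueError("limit_k_to_size can only be True when k is not None")
    # bool is a subclass of int in Python, so reject it explicitly.
    if k is not None and (isinstance(k, bool) or not isinstance(k, int) or k <= 0):
        raise ValueError(f"k must be a positive integer, got {k!r}")


def check_update_args(
    input_shape: Sequence[int],
    target_shape: Sequence[int],
    num_queries: int,
    indexes: Optional[Sequence[int]],
) -> None:
    """update()-time checks on the input/target shapes and on indexes (sketch)."""
    if len(input_shape) != 1 or tuple(input_shape) != tuple(target_shape):
        raise ValueError(
            "input and target must be 1-D with matching shapes, "
            f"got {tuple(input_shape)} and {tuple(target_shape)}"
        )
    if num_queries > 1 and indexes is None:
        raise ValueError("indexes must be provided when num_queries > 1")


def check_target(target: Sequence[int], empty_target_action: str) -> None:
    """The "err" case: every target entry is 0."""
    if empty_target_action == "err" and all(t == 0 for t in target):
        raise ValueError("target has no positive entry and empty_target_action='err'")
```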
__init__(*, empty_target_action: Union[Literal['neg'], Literal['pos'], Literal['skip'], Literal['err']] = 'neg', k: Optional[int] = None, limit_k_to_size: bool = False, num_queries: int = 1, avg: Optional[Union[Literal['macro'], Literal['none']]] = None, device: Optional[device] = None) → None

Initialize a metric object and its internal states.

Use self._add_state() to initialize state variables of your metric class. The state variables should be either torch.Tensor, a list of torch.Tensor, or a dictionary with torch.Tensor as values.
Methods
- __init__(*[, empty_target_action, k, ...]) – Initialize a metric object and its internal states.
- compute() – Implement this method to compute and return the final metric value from state variables.
- load_state_dict(state_dict[, strict]) – Loads metric state variables from state_dict.
- merge_state(metrics) – Merge the metric state with its counterparts from other metric instances.
- reset() – Reset the metric state variables to their default value.
- state_dict() – Save metric state variables in state_dict.
- to(device, *args, **kwargs) – Move tensors in metric state variables to device.
- update(input, target[, indexes]) – Update the metric state with ground truth labels and predictions.
- update_single_query(i, input, target)

Attributes
- device – The last input device of Metric.to().
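The methods listed above follow the stateful pattern of torcheval metrics: update() folds a batch into the state, merge_state() combines the state of other instances, compute() turns the state into the metric value, and reset() restores the defaults. The toy class below is a hypothetical, pure-Python illustration of that lifecycle, not torcheval code: it handles a single query and only the "neg" behaviour, and it assumes raw scores and labels are kept until compute(). Under that assumption, recall over the merged data is not the mean of per-batch recalls.

```python
from typing import List, Optional, Sequence


class ToyRetrievalRecall:
    """Hypothetical pure-Python stand-in for the stateful metric pattern.

    Not torcheval: a single query, the "neg" empty-target behaviour only,
    and raw scores/labels are accumulated until compute() (an assumption).
    """

    def __init__(self, k: Optional[int] = None) -> None:
        self.k = k
        self.reset()

    def reset(self) -> "ToyRetrievalRecall":
        # Return the state to its default (empty) value.
        self.scores: List[float] = []
        self.target: List[int] = []
        return self

    def update(self, scores: Sequence[float], target: Sequence[int]) -> "ToyRetrievalRecall":
        # Accumulate this batch; nothing is ranked yet.
        self.scores.extend(scores)
        self.target.extend(target)
        return self

    def merge_state(self, metrics: Sequence["ToyRetrievalRecall"]) -> "ToyRetrievalRecall":
        # Fold in the accumulated state of other instances.
        for other in metrics:
            self.scores.extend(other.scores)
            self.target.extend(other.target)
        return self

    def compute(self) -> float:
        # Rank everything seen so far, then take recall over the top k.
        num_relevant = sum(1 for t in self.target if t > 0)
        if num_relevant == 0:
            return 0.0  # mirrors empty_target_action="neg"
        ranking = sorted(range(len(self.scores)), key=lambda i: self.scores[i], reverse=True)
        top = ranking if self.k is None else ranking[: self.k]
        return sum(1 for i in top if self.target[i] > 0) / num_relevant


# The doc example's data, split into two batches handled by two instances.
a = ToyRetrievalRecall(k=2).update([0.2, 0.3, 0.5, 0.1], [0, 0, 1, 1])
b = ToyRetrievalRecall(k=2).update([0.3, 0.5, 0.2], [1, 0, 1])
print(a.compute(), b.compute())  # 0.5 0.5 (per-batch values)
print(a.merge_state([b]).compute())  # 0.25 (same as one update on the full data)
```

Averaging the two per-batch values (0.5 each) would give 0.5, while ranking the merged entries reproduces the 0.25 of the single-update example.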