from abc import ABCMeta, abstractmethod
from collections import namedtuple
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch

from ignite.exceptions import NotComputableError

# These decorators help with distributed settings
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.nlp.utils import lcs, ngrams

__all__ = ["Rouge", "RougeN", "RougeL"]


class Score(namedtuple("Score", ["match", "candidate", "reference"])):
    r"""
    Computes precision and recall for given matches, candidate and reference lengths.
    """

    def precision(self) -> float:
        """
        Calculates precision.
        """
        return self.match / self.candidate if self.candidate > 0 else 0

    def recall(self) -> float:
        """
        Calculates recall.
        """
        return self.match / self.reference if self.reference > 0 else 0


def compute_ngram_scores(candidate: Sequence[Any], reference: Sequence[Any], n: int = 4) -> Score:
    """
    Compute the score based on ngram co-occurrence of sequences of items

    Args:
        candidate: candidate sequence of items
        reference: reference sequence of items
        n: ngram order

    Returns:
        The score containing the number of ngram co-occurrences

    .. versionadded:: 0.4.5
    """
    # ngrams of the candidate
    candidate_counter = ngrams(candidate, n)
    # ngrams of the reference
    reference_counter = ngrams(reference, n)
    # ngram co-occurrences in the candidate and the reference
    match_counters = candidate_counter & reference_counter
    # the score keeps the raw match / candidate / reference counts;
    # precision and recall are derived from them in ``Score``
    return Score(
        match=sum(match_counters.values()),
        candidate=sum(candidate_counter.values()),
        reference=sum(reference_counter.values()),
    )

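# Illustrative sketch (not part of the library's tests): how ``compute_ngram_scores``
# behaves on a toy pair of token lists, assuming ``ngrams`` returns a
# ``collections.Counter`` of n-gram tuples (which is what the ``&`` intersection and
# ``.values()`` calls above rely on). The sentences are made up for demonstration.
#
#   >>> score = compute_ngram_scores("the cat sat".split(), "the cat is on the mat".split(), n=1)
#   >>> score
#   Score(match=2, candidate=3, reference=6)
#   >>> round(score.precision(), 2), round(score.recall(), 2)
#   (0.67, 0.33)
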
def compute_lcs_scores(candidate: Sequence[Any], reference: Sequence[Any]) -> Score:
    """
    Compute the score based on longest common subsequence of sequences of items

    Args:
        candidate: candidate sequence of items
        reference: reference sequence of items

    Returns:
        The score containing the length of longest common subsequence

    .. versionadded:: 0.4.5
    """
    # lcs of candidate and reference
    match = lcs(candidate, reference)
    # the score keeps the raw match / candidate / reference counts;
    # precision and recall are derived from them in ``Score``
    return Score(match=match, candidate=len(candidate), reference=len(reference))


class MultiRefReducer(metaclass=ABCMeta):
    r"""
    Reducer interface for multi-reference
    """

    @abstractmethod
    def __call__(self, scores: Sequence[Score]) -> Score:
        pass


class MultiRefAverageReducer(MultiRefReducer):
    r"""
    Reducer for averaging the scores
    """

    def __call__(self, scores: Sequence[Score]) -> Score:
        match = sum([score.match for score in scores])
        candidate = sum([score.candidate for score in scores])
        reference = sum([score.reference for score in scores])
        return Score(match=match, candidate=candidate, reference=reference)


class MultiRefBestReducer(MultiRefReducer):
    r"""
    Reducer for selecting the best score
    """

    def __call__(self, scores: Sequence[Score]) -> Score:
        return max(scores, key=lambda x: x.recall())


class _BaseRouge(Metric):
    r"""
    Rouge interface for Rouge-L and Rouge-N
    """

    _state_dict_all_req_keys = ("_recall", "_precision", "_fmeasure", "_num_examples")

    def __init__(
        self,
        multiref: str = "average",
        alpha: float = 0,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ) -> None:
        super(_BaseRouge, self).__init__(output_transform=output_transform, device=device)
        self._alpha = alpha
        if not 0 <= self._alpha <= 1:
            raise ValueError(f"alpha must be in interval [0, 1] (got : {self._alpha})")
        self._multiref = multiref
        valid_multiref = ["best", "average"]
        if self._multiref not in valid_multiref:
            raise ValueError(f"multiref : valid values are {valid_multiref} (got : {self._multiref})")
        self._multiref_reducer = self._get_multiref_reducer()

    def _get_multiref_reducer(self) -> MultiRefReducer:
        if self._multiref == "average":
            return MultiRefAverageReducer()
        return MultiRefBestReducer()

    @reinit__is_reduced
    def reset(self) -> None:
        self._recall = 0.0
        self._precision = 0.0
        self._fmeasure = 0.0
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: Tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
        candidates, references = output
        for _candidate, _reference in zip(candidates, references):
            multiref_scores = [self._compute_score(candidate=_candidate, reference=_ref) for _ref in _reference]
            score = self._multiref_reducer(multiref_scores)
            precision = score.precision()
            recall = score.recall()
            self._precision += precision
            self._recall += recall
            precision_recall = precision * recall
            if precision_recall > 0:  # avoid zero division
                self._fmeasure += precision_recall / ((1 - self._alpha) * precision + self._alpha * recall)
            self._num_examples += 1

    @sync_all_reduce("_precision", "_recall", "_fmeasure", "_num_examples")
    def compute(self) -> Mapping:
        if self._num_examples == 0:
            raise NotComputableError("Rouge metric must have at least one example before it can be computed")
        return {
            f"{self._metric_name()}-P": float(self._precision / self._num_examples),
            f"{self._metric_name()}-R": float(self._recall / self._num_examples),
            f"{self._metric_name()}-F": float(self._fmeasure / self._num_examples),
        }

    @abstractmethod
    def _compute_score(self, candidate: Sequence[Any], reference: Sequence[Any]) -> Score:
        pass

    @abstractmethod
    def _metric_name(self) -> str:
        pass

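# Illustrative sketch (made-up Score values, not from the library's test suite) of the
# two multi-reference reduction strategies used by ``_BaseRouge``:
#
#   >>> scores = [Score(match=2, candidate=4, reference=5), Score(match=3, candidate=4, reference=4)]
#   >>> MultiRefAverageReducer()(scores)     # sums the counts, i.e. a micro-average over references
#   Score(match=5, candidate=8, reference=9)
#   >>> MultiRefBestReducer()(scores).recall()   # keeps the score with the highest recall
#   0.75
#
# The F-measure accumulated in ``_BaseRouge.update`` is the alpha-weighted harmonic mean
# F = P * R / ((1 - alpha) * P + alpha * R), so alpha = 0 reduces it to recall and
# alpha = 1 reduces it to precision, matching the ``alpha`` description in the docstrings below.
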
class RougeN(_BaseRouge):
    r"""Calculates the Rouge-N score.

    The Rouge-N is based on the ngram co-occurrences of candidates and references.

    More details can be found in `Lin 2004`__.

    __ https://www.aclweb.org/anthology/W04-1013.pdf

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y_pred` (list(list(str))) must be a list of tokenized candidates.
    - `y` (list(list(list(str)))) must be a list of lists of tokenized references, one list per candidate.

    Args:
        ngram: ngram order (default: 4).
        multiref: reduces scores for multi references. Valid values are "best" and "average"
            (default: "average").
        alpha: controls the importance between recall and precision
            (alpha -> 0: recall is more important, alpha -> 1: precision is more important).
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a
            multi-output model and you want to compute the metric with respect to one of the outputs.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update``
            method is non-blocking. By default, CPU.

    Examples:
        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. testcode::

            from ignite.metrics import RougeN

            m = RougeN(ngram=2, multiref="best")

            candidate = "the cat is not there".split()
            references = [
                "the cat is on the mat".split(),
                "there is a cat on the mat".split()
            ]

            m.update(([candidate], [references]))

            print(m.compute())

        .. testoutput::

            {'Rouge-2-P': 0.5, 'Rouge-2-R': 0.4, 'Rouge-2-F': 0.4}

    .. versionadded:: 0.4.5
    """

    def __init__(
        self,
        ngram: int = 4,
        multiref: str = "average",
        alpha: float = 0,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        super(RougeN, self).__init__(
            multiref=multiref, alpha=alpha, output_transform=output_transform, device=device
        )
        self._ngram = ngram
        if self._ngram < 1:
            raise ValueError(f"ngram order must be greater than zero (got : {self._ngram})")

    def _compute_score(self, candidate: Sequence[Any], reference: Sequence[Any]) -> Score:
        return compute_ngram_scores(candidate=candidate, reference=reference, n=self._ngram)

    def _metric_name(self) -> str:
        return f"Rouge-{self._ngram}"

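# Minimal usage sketch (assumed setup, mirroring the ``.. testcode::`` above rather than
# documenting an additional API): attaching ``RougeN`` to an ignite ``Engine`` whose process
# function already yields ``(y_pred, y)`` in the expected
# ``([tokenized candidates], [[references for candidate 0], ...])`` form.
#
#   from ignite.engine import Engine
#
#   def eval_step(engine, batch):
#       return batch  # batch is assumed to already be in (y_pred, y) form
#
#   evaluator = Engine(eval_step)
#   RougeN(ngram=2, multiref="best").attach(evaluator, "rouge-2")
#
#   data = [(
#       ["the cat is not there".split()],
#       [["the cat is on the mat".split(), "there is a cat on the mat".split()]],
#   )]
#   state = evaluator.run(data)
#   print(state.metrics["rouge-2"])   # {'Rouge-2-P': 0.5, 'Rouge-2-R': 0.4, 'Rouge-2-F': 0.4}
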
class RougeL(_BaseRouge):
    r"""Calculates the Rouge-L score.

    The Rouge-L is based on the length of the longest common subsequence of candidates and references.

    More details can be found in `Lin 2004`__.

    __ https://www.aclweb.org/anthology/W04-1013.pdf

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y_pred` (list(list(str))) must be a list of tokenized candidates.
    - `y` (list(list(list(str)))) must be a list of lists of tokenized references, one list per candidate.

    Args:
        multiref: reduces scores for multi references. Valid values are "best" and "average"
            (default: "average").
        alpha: controls the importance between recall and precision
            (alpha -> 0: recall is more important, alpha -> 1: precision is more important).
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a
            multi-output model and you want to compute the metric with respect to one of the outputs.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update``
            method is non-blocking. By default, CPU.

    Examples:
        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. testcode::

            from ignite.metrics import RougeL

            m = RougeL(multiref="best")

            candidate = "the cat is not there".split()
            references = [
                "the cat is on the mat".split(),
                "there is a cat on the mat".split()
            ]

            m.update(([candidate], [references]))

            print(m.compute())

        .. testoutput::

            {'Rouge-L-P': 0.6, 'Rouge-L-R': 0.5, 'Rouge-L-F': 0.5}

    .. versionadded:: 0.4.5
    """

    def __init__(
        self,
        multiref: str = "average",
        alpha: float = 0,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        super(RougeL, self).__init__(
            multiref=multiref, alpha=alpha, output_transform=output_transform, device=device
        )

    def _compute_score(self, candidate: Sequence[Any], reference: Sequence[Any]) -> Score:
        return compute_lcs_scores(candidate=candidate, reference=reference)

    def _metric_name(self) -> str:
        return "Rouge-L"

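# Illustrative sketch of the underlying ``compute_lcs_scores`` call for the first reference
# in the ``.. testcode::`` above, assuming ``lcs`` returns the length of the longest common
# subsequence as an integer: the LCS of the two sentences is "the cat is" (3 tokens), the
# candidate has 5 tokens and the reference 6, which yields the 0.6 precision and 0.5 recall
# reported in the ``.. testoutput::``.
#
#   >>> compute_lcs_scores("the cat is not there".split(), "the cat is on the mat".split())
#   Score(match=3, candidate=5, reference=6)
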
class Rouge(Metric):
    r"""Calculates the Rouge score for multiple Rouge-N and Rouge-L metrics.

    More details can be found in `Lin 2004`__.

    __ https://www.aclweb.org/anthology/W04-1013.pdf

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y_pred` (list(list(str))) must be a list of tokenized candidates.
    - `y` (list(list(list(str)))) must be a list of lists of tokenized references, one list per candidate.

    Args:
        variants: set of metrics to compute. Valid inputs are "L" and integers 1 <= n <= 9
            (default: [1, 2, 4, "L"]).
        multiref: reduces scores for multi references. Valid values are "best" and "average"
            (default: "average").
        alpha: controls the importance between recall and precision
            (alpha -> 0: recall is more important, alpha -> 1: precision is more important).
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a
            multi-output model and you want to compute the metric with respect to one of the outputs.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update``
            method is non-blocking. By default, CPU.

    Examples:
        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. testcode::

            from ignite.metrics import Rouge

            m = Rouge(variants=["L", 2], multiref="best")

            candidate = "the cat is not there".split()
            references = [
                "the cat is on the mat".split(),
                "there is a cat on the mat".split()
            ]

            m.update(([candidate], [references]))

            print(m.compute())

        .. testoutput::

            {'Rouge-L-P': 0.6, 'Rouge-L-R': 0.5, 'Rouge-L-F': 0.5, 'Rouge-2-P': 0.5, 'Rouge-2-R': 0.4, 'Rouge-2-F': 0.4}

    .. versionadded:: 0.4.5

    .. versionchanged:: 0.4.7
        ``update`` method has changed and now works on batches of inputs.
    """

    _state_dict_all_req_keys = ("internal_metrics",)

    def __init__(
        self,
        variants: Optional[Sequence[Union[str, int]]] = None,
        multiref: str = "average",
        alpha: float = 0,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        if variants is None or len(variants) == 0:
            variants = [1, 2, 4, "L"]
        self.internal_metrics: List[_BaseRouge] = []
        for m in variants:
            variant: Optional[_BaseRouge] = None
            if isinstance(m, str) and m == "L":
                variant = RougeL(multiref=multiref, alpha=alpha, output_transform=output_transform, device=device)
            elif isinstance(m, int):
                variant = RougeN(
                    ngram=m, multiref=multiref, alpha=alpha, output_transform=output_transform, device=device
                )
            else:
                raise ValueError("variant must be 'L' or an integer greater than zero")
            self.internal_metrics.append(variant)
        super(Rouge, self).__init__(output_transform=output_transform, device=device)
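    # Note (illustrative, based on ``__init__`` above): ``Rouge(variants=[1, 2, "L"])`` builds
    # ``internal_metrics == [RougeN(ngram=1), RougeN(ngram=2), RougeL()]``, and the combined
    # result shown in the ``.. testoutput::`` is the union of the per-variant dictionaries
    # (e.g. the 'Rouge-L-*' and 'Rouge-2-*' keys for ``variants=["L", 2]``).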