"""
Closed-form Kullback-Leibler divergences between pairs of distributions.

Pairwise implementations are registered per ``(type_p, type_q)`` with
:func:`register_kl` and looked up by :func:`kl_divergence`.
"""
import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma


# Source of truth mapping a few general (type, type) pairs to functions.
# Written only by register_kl's decorator.
_KL_REGISTRY: Dict[Tuple[Type, Type], Callable] = ({})
# Memoized version mapping many specific (type, type) pairs to functions.
# Filled lazily by kl_divergence and cleared on every new registration.
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {}

__all__ = ["register_kl", "kl_divergence"]
def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type,type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is issued. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.

    Returns:
        A decorator that stores the decorated function under
        ``(type_p, type_q)`` and returns the function unchanged.

    Raises:
        TypeError: If ``type_p`` or ``type_q`` is not a class deriving from
            :class:`~torch.distributions.Distribution`.
    """
    # Both conditions must hold. ``isinstance`` runs first so that
    # ``issubclass`` is never called on a non-class, which would raise its own,
    # less informative TypeError.
    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator
@total_orderingclass_Match:__slots__=["types"]def__init__(self,*types):self.types=typesdef__eq__(self,other):returnself.types==other.typesdef__le__(self,other):forx,yinzip(self.types,other.types):ifnotissubclass(x,y):returnFalseifxisnoty:breakreturnTruedef_dispatch_kl(type_p,type_q):""" Find the most specific approximate match, assuming single inheritance. """matches=[(super_p,super_q)forsuper_p,super_qin_KL_REGISTRYifissubclass(type_p,super_p)andissubclass(type_q,super_q)]ifnotmatches:returnNotImplemented# Check that the left- and right- lexicographic orders agree.# mypy isn't smart enough to know that _Match implements __lt__# see: https://github.com/python/typing/issues/760#issuecomment-710670503left_p,left_q=min(_Match(*m)forminmatches).types# type: ignore[type-var]right_q,right_p=min(_Match(*reversed(m))forminmatches).types# type: ignore[type-var]left_fun=_KL_REGISTRY[left_p,left_q]right_fun=_KL_REGISTRY[right_p,right_q]ifleft_funisnotright_fun:warnings.warn("Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(type_p.__name__,type_q.__name__,left_p.__name__,right_q.__name__),RuntimeWarning,)returnleft_fundef_infinite_like(tensor):""" Helper function for obtaining infinite KL Divergence throughout """returntorch.full_like(tensor,inf)def_x_log_x(tensor):""" Utility function for calculating x log x """returntensor*tensor.log()def_batch_trace_XXT(bmat):""" Utility function for calculating the trace of XX^{T} with X having arbitrary trailing batch dimensions """n=bmat.size(-1)m=bmat.size(-2)flat_trace=bmat.reshape(-1,m*n).pow(2).sum(-1)returnflat_trace.reshape(bmat.shape[:-2])
def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    # Resolve the implementation once per concrete (type(p), type(q)) pair and
    # cache it; a NotImplemented lookup result is cached as well.
    pair = (type(p), type(q))
    if pair in _KL_MEMOIZE:
        fun = _KL_MEMOIZE[pair]
    else:
        fun = _dispatch_kl(*pair)
        _KL_MEMOIZE[pair] = fun

    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)
################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    """KL(Bernoulli || Bernoulli), elementwise over the batch.

    Uses -log(probs) = softplus(-logits) and -log(1 - probs) = softplus(logits).
    """
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    # q cannot produce 1 -> infinite; then p never produces 1 -> 0 (0*log0 = 0).
    # The second assignment deliberately overrides the first.
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    """Closed-form KL(Beta || Beta) from lgamma/digamma of the concentrations."""
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    """KL(Binomial || Binomial).

    Raises NotImplementedError if any ``p.total_count < q.total_count``.
    Entries with ``p.total_count > q.total_count`` are set to inf, because p
    can then produce counts that q cannot.
    """
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    kl = p.total_count * (
        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
    )
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    """KL(Categorical || Categorical), summed over the last (category) dim."""
    t = p.probs * (p.logits - q.logits)
    # Categories q excludes -> inf; categories p excludes -> 0 (overrides inf).
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    """Closed-form KL between two ContinuousBernoulli distributions."""
    t1 = p.mean * (p.logits - q.logits)
    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return t1 + t2 + t3


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    """Closed-form KL(Dirichlet || Dirichlet); event dim is the last dim."""
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    """KL(Exp(p.rate) || Exp(q.rate)) = r - log(r) - 1 with r = q.rate / p.rate."""
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    """Generic same-family KL as the Bregman divergence of the log-normalizer A.

    KL = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>, with grad A
    obtained by autograd (``create_graph=True`` keeps it differentiable).
    Raises NotImplementedError when p and q are different concrete types.
    """
    if type(p) is not type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families cannot "
            "be computed using Bregman divergences"
        )
    p_nparams = [param.detach().requires_grad_() for param in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        # Reduce over the event dims so the result has the batch shape.
        result -= _sum_rightmost(term, len(q.event_shape))
    return result


@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    """Closed-form KL(Gamma || Gamma) in the (concentration, rate) parameterization."""
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    """Closed-form KL(Gumbel || Gumbel) using the Euler-Mascheroni constant."""
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    """KL(Geometric || Geometric) = -H(p) - E_p[log q]."""
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    """KL between HalfNormals equals the KL between their base Normals."""
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    """Closed-form KL(Laplace || Laplace)."""
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    """KL between two low-rank MVNs (cov = W @ W.T + diag(D)).

    Raises ValueError if the event shapes differ.
    """
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    # log|q.cov| - log|p.cov|
    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    """KL(full MVN || low-rank MVN). Raises ValueError if event shapes differ."""
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    # log|p.cov| = 2 * sum(log(diag(p.scale_tril)))
    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    """KL(low-rank MVN || full MVN). Raises ValueError if event shapes differ."""
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    # The factors are expanded to a common batch shape before the triangular solves.
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(
        combined_batch_shape + (n, p.cov_factor.size(-1))
    )
    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
        combined_batch_shape + (n, n)
    )
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    """KL between two full-covariance MVNs via their Cholesky factors.

    Raises ValueError if the event shapes differ.
    """
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    # 0.5 * (log|q.cov| - log|p.cov|), from the Cholesky diagonals.
    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    # tr(inv(q.cov) @ p.cov) = ||inv(q_tril) @ p_tril||_F^2
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    """Closed-form KL(Normal || Normal)."""
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    """Delegates to the KL of the wrapped Categorical distributions."""
    return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    """KL(Pareto || Pareto); inf where p's support starts below q's."""
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    """Closed-form KL(Poisson || Poisson)."""
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    """KL is invariant under a shared transform, so compare the base distributions.

    Raises NotImplementedError unless both the transforms and the event shapes match.
    """
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    """KL(Uniform || Uniform); inf where q's interval does not contain p's."""
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result
# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    """KL(Bernoulli || Poisson) = -H(p) - E_p[log q]."""
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    """KL(Beta || ContinuousBernoulli) = -H(p) - E_p[log q], linear in p.mean."""
    return (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    """Always returns inf (shaped like p.concentration1).

    Presumably because Pareto's support excludes part of Beta's [0, 1] support.
    """
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    """KL(Beta || Exponential); the last term is rate * E_p[X]."""
    return (
        -p.entropy()
        - q.rate.log()
        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
    )


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    """KL(Beta || Gamma) = -H(p) - E_p[log q]."""
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    # t3 uses E_p[log X] = digamma(a) - digamma(a + b).
    t3 = (q.concentration - 1) * (
        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
    )
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4


# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    """KL(Beta || Normal) = -H(p) - E_p[log q]."""
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    # Half of E_p[X^2] = (Var_p[X] + E_p[X]^2) / 2.
    t3 = (
        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
        + E_beta.pow(2)
    ) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    """KL(Beta || Uniform); inf where q's interval does not cover p's support."""
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result


# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    """Always returns inf (shaped like p.probs).

    Presumably because Pareto's support excludes part of the unit interval.
    """
    return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    """KL(ContinuousBernoulli || Exponential) = -H(p) - log(rate) + rate * E_p[X]."""
    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean


# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    """KL(ContinuousBernoulli || Normal) = -H(p) - E_p[log q], via p's mean and variance."""
    t1 = -p.entropy()
    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
        q.scale
    )
    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
        2.0 * torch.square(q.scale)
    )
    return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    """KL(ContinuousBernoulli || Uniform).

    Finite whenever q's interval covers p's closed support [0, 1]; inf otherwise.
    A Uniform whose endpoint coincides with 0 or 1 (e.g. Uniform(0, 1)) still
    covers the support up to a measure-zero point, so only strict violations
    give inf. This matches ``_kl_beta_uniform``.
    """
    result = -p.entropy() + (q.high - q.low).log()
    not_covered = (q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)
    return torch.where(not_covered, _infinite_like(result), result)


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    """Always returns inf (shaped like p.rate).

    Presumably because q's support does not cover p's support.
    """
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    """Closed-form KL(Exponential || Gamma); uses the Euler-Mascheroni constant."""
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return (
        t1
        + ratio
        + q.concentration.lgamma()
        + q.concentration * _euler_gamma
        - (1 + _euler_gamma)
    )


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    """Closed-form KL(Exponential || Gumbel)."""
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3
# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    """KL(Exponential || Normal) = -H(p) - E_p[log q]."""
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()  # half of E_p[X^2] = 2 / rate^2
    t3 = q.loc / p.rate
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    """Always returns inf (shaped like p.concentration).

    Presumably because q's support does not cover p's support.
    """
    return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    """KL(Gamma || Exponential) = -H(p) - log(rate_q) + rate_q * E_p[X]."""
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    """Closed-form KL(Gamma || Gumbel)."""
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = (
        torch.exp(loc_scale_ratio)
        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
        - loc_scale_ratio
    )
    return t1 + t2 + t3


# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    """KL(Gamma || Normal) = -H(p) - E_p[log q]."""
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = (
        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    # Half of E_p[X^2] = (a^2 + a) / (2 * rate^2).
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return (
        t1
        + (p.concentration - 1) * p.concentration.digamma()
        + (t2 - t3 + t4) / var_normal
    )


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    """Always returns inf (shaped like p.loc).

    Presumably because Gumbel's real-line support is not contained in q's.
    """
    return _infinite_like(p.loc)
# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    """Closed-form KL(Gumbel || Normal)."""
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    """Always returns inf (shaped like p.loc).

    Presumably because Laplace's real-line support is not contained in q's.
    """
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    """Closed-form KL(Laplace || Normal)."""
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    """Always returns inf (shaped like p.loc).

    Presumably because Normal's real-line support is not contained in q's.
    """
    return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    """Closed-form KL(Normal || Gumbel)."""
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    """Closed-form KL(Normal || Laplace); t2 + t3 equal E_p|X - q.loc|."""
    loc_diff = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    loc_diff_scale_ratio = loc_diff / p.scale
    t1 = torch.log(scale_ratio)
    t2 = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
    )
    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    """Always returns inf (shaped like p.scale).

    Presumably because Pareto's unbounded support is not contained in q's.
    """
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    """KL(Pareto || Exponential); inf where alpha <= 1 (p's mean does not exist)."""
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    """KL(Pareto || Gamma); inf where alpha <= 1 (p's mean does not exist)."""
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result


# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    """KL(Pareto || Normal); inf where alpha <= 2 (p's variance does not exist)."""
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    """Always returns inf (shaped like p.rate).

    Presumably because Poisson's unbounded counts are not all in q's support.
    """
    return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    """KL(Uniform || Beta); inf where p's interval leaves Beta's support."""
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    """KL(Uniform || ContinuousBernoulli).

    Finite whenever p's interval lies inside q's closed support [0, 1]; inf
    otherwise. An interval touching 0 or 1 (e.g. Uniform(0, 0.5)) stays inside
    the support, so only strict violations give inf. This matches
    ``_kl_uniform_beta``.
    """
    result = (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )
    outside = (p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)
    return torch.where(outside, _infinite_like(result), result)


@register_kl(Uniform, Exponential)
def _kl_uniform_exponetial(p, q):
    """KL(Uniform || Exponential); inf where p.low is below q's support."""
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    """KL(Uniform || Gamma); inf where p.low is below q's support."""
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    """Closed-form KL(Uniform || Gumbel)."""
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2


# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    """KL(Uniform || Normal); t2 is p's variance, t3 the squared mean offset."""
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    """KL(Uniform || Pareto); inf where p.low is below q's support."""
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    """KL of the base distributions summed over the reinterpreted batch dims.

    Raises NotImplementedError when reinterpreted_batch_ndims differ.
    """
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)


@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    """Closed-form KL(Cauchy || Cauchy)."""
    # From https://arxiv.org/abs/1905.10965
    t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
    t2 = (4 * p.scale * q.scale).log()
    return t1 - t2


def _add_kl_info():
    """Appends a list of implemented KL functions to the doc for kl_divergence.

    Pairs are sorted by (p class name, q class name). Each call appends again,
    so this is not idempotent.
    """
    rows = [
        "KL divergence is currently implemented for the following distribution pairs:"
    ]
    for p, q in sorted(
        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
    ):
        rows.append(
            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
        )
    kl_info = "\n\t".join(rows)
    # __doc__ is None under `python -OO`; skip in that case.
    if kl_divergence.__doc__:
        kl_divergence.__doc__ += kl_info  # type: ignore[operator]
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.