Source code for torch.distributions.lowrank_multivariate_normal
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ["LowRankMultivariateNormal"]


def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices
    :math:`W` and a batch of vectors :math:`D`.
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    # Add the identity matrix to K in place through its (batched) diagonal view.
    K.diagonal(dim1=-2, dim2=-1).add_(1)
    return torch.linalg.cholesky(K)


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::

        log|W @ W.T + D| = log|C| + log|D|,

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to
    compute the log determinant.
    """
    # log|C| from its Cholesky factor; log|D| is a plain sum over the diagonal.
    log_det_capacitance = 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    return log_det_capacitance + D.log().sum(-1)


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to
    compute the squared Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    # Diagonal term inv(D), then the low-rank Woodbury correction.
    diag_term = (x.pow(2) / D).sum(-1)
    correction_term = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return diag_term - correction_term
class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank
    form parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        # Broadcast batch shapes by temporarily viewing loc / cov_diag as columns,
        # so they align with the trailing (event, rank) dims of cov_factor.
        loc_col = loc.unsqueeze(-1)
        diag_col = cov_diag.unsqueeze(-1)
        try:
            loc_col, self.cov_factor, diag_col = torch.broadcast_tensors(
                loc_col, cov_factor, diag_col
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        self.loc = loc_col[..., 0]
        self.cov_diag = diag_col[..., 0]
        batch_shape = self.loc.shape[:-1]

        # Keep the unbroadcasted parameters around: the lazy properties below
        # operate on these smaller tensors and expand only at the end.
        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @property
    def mean(self):
        # The mean of a (low-rank) multivariate normal is its location parameter.
        return self.loc

    @property
    def mode(self):
        # The density is maximized at the mean.
        return self.loc

    @lazy_property
    def variance(self):
        # Marginal variances: diag(W @ W.T + D) = rowwise ||W||^2 + D.
        marginal_var = (
            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
        )
        return marginal_var.expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to increase the numerically computation
        # stability for Cholesky decomposition (see
        # http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from
        # below by 1, hence it is well-conditioned and safe to take Cholesky
        # decomposition.
        diag_sqrt_col = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / diag_sqrt_col
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        # Add the identity matrix to K in place through its diagonal view.
        K.diagonal(dim1=-2, dim2=-1).add_(1)
        tril = diag_sqrt_col * torch.linalg.cholesky(K)
        return tril.expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        # Materialize the full covariance: W @ W.T plus the diagonal part.
        cov = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return cov.expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        prec = torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        return prec.expand(self._batch_shape + self._event_shape + self._event_shape)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.