Source code for torch.distributions.multivariate_normal
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ["MultivariateNormal"]


def _batch_mv(bmat, bvec):
    r"""
    Performs a batched matrix-vector product, with compatible but different batch shapes.

    This function takes as input `bmat`, containing :math:`n \times n` matrices, and
    `bvec`, containing length :math:`n` vectors.

    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
    to a batch shape. They are not necessarily assumed to have the same batch shape,
    just ones which can be broadcasted.
    """
    # Treat each vector as an (n, 1) column so matmul broadcasts the batch
    # dimensions, then drop the trailing singleton dimension again.
    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)


def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both bL and bx. They are not necessarily assumed to have
    the same batch shape, but the batch shape of `bL` should be broadcastable to
    the batch shape of `bx`.

    Args:
        bL: lower-triangular factors, with trailing dimensions ``(n, n)``.
        bx: vectors, with trailing dimension ``n``.

    Returns:
        A tensor of shape ``bx.shape[:-1]`` holding the squared distances.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims

    # Reshape bx with the shape (..., 1, i, j, 1, n): every batch dimension
    # shared with bL is split into a (sx // sL, sL) pair.
    bx_new_shape = bx.shape[:outer_batch_dims]
    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)

    # Permute bx to make it have shape (..., 1, j, i, 1, n): the first member
    # of each pair is moved ahead of all the second members.
    permute_dims = (
        list(range(outer_batch_dims))
        + list(range(outer_batch_dims, new_batch_dims, 2))
        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
        + [new_batch_dims]
    )
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    # Solve L z = x for each column; the squared norm of z is x^T (L L^T)^{-1} x.
    M_swap = (
        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
    )  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)


def _precision_to_scale_tril(P):
    r"""
    Converts a precision matrix into the lower-triangular factor of its inverse.

    The Cholesky factor of the row- and column-reversed `P`, reversed back and
    transposed, gives a lower-triangular ``L_inv`` with
    :math:`\mathbf{P} = \mathbf{L}_{inv}^\top \mathbf{L}_{inv}`. Inverting it by
    a triangular solve against the identity yields a lower-triangular ``L``
    with :math:`\mathbf{P}^{-1} = \mathbf{L}\mathbf{L}^\top`.

    Args:
        P: precision matrix, with trailing dimensions ``(n, n)``.

    Returns:
        The lower-triangular matrix ``L``.
    """
    # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
    Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
    L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
    return L
class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(
        self,
        loc,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        validate_args=None,
    ):
        """
        Stores the broadcast parameters and the lower-triangular scale factor.

        Raises:
            ValueError: if ``loc`` has no dimensions, if the number of matrix
                arguments supplied is not exactly one, or if the supplied
                matrix has fewer than two dimensions.
        """
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        # Summing the three booleans counts how many parameterizations were given.
        if (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) != 1:
            raise ValueError(
                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
            )

        # Broadcast the batch dimensions of the supplied matrix against those of
        # loc, and store the matrix expanded to the common batch shape.
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError(
                    "scale_tril matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                covariance_matrix.shape[:-2], loc.shape[:-1]
            )
            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError(
                    "precision_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                precision_matrix.shape[:-2], loc.shape[:-1]
            )
            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
        self.loc = loc.expand(batch_shape + (-1,))

        event_shape = self.loc.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        # The scale factor is derived from the matrix exactly as it was passed
        # in, i.e. before it was expanded to the broadcast batch shape.
        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.