class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct a batch of 3 Gaussian Mixture Models in 2D each
        # consisting of 5 randomly weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`.
        component_distribution: `torch.distributions.Distribution`-like
            instance. Rightmost batch dimension indexes components.
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(self,
                 mixture_distribution,
                 component_distribution,
                 validate_args=None):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError("The Mixture distribution needs to be an "
                             "instance of torch.distributions.Categorical")

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError("The Component distribution needs to be an "
                             "instance of torch.distributions.Distribution")

        # Check that batch sizes match
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError("`mixture_distribution.batch_shape` ({0}) is not "
                                 "compatible with `component_distribution."
                                 "batch_shape` ({1})".format(mdbs, cdbs))

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError("`mixture_distribution component` ({0}) does not"
                             " equal `component_distribution.batch_shape[-1]`"
                             " ({1})".format(km, kc))
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super(MixtureSameFamily, self).__init__(batch_shape=cdbs,
                                                event_shape=event_shape,
                                                validate_args=validate_args)

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(probs * self.component_distribution.mean,
                         dim=-1 - self._event_ndims)  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(probs * self.component_distribution.variance,
                                  dim=-1 - self._event_ndims)
        var_cond_mean = torch.sum(probs * (self.component_distribution.mean -
                                           self._pad(self.mean)).pow(2.0),
                                  dim=-1 - self._event_ndims)
        return mean_cond_var + var_cond_mean
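```python
# Illustrative sketch (not part of torch.distributions): the batch-shape
# compatibility rule checked in `MixtureSameFamily.__init__`, restated on
# plain tuples. Shapes are aligned from the right; each pair of sizes must be
# equal, or one of them must be 1. Because `zip` stops at the shorter shape,
# extra leading dimensions on either side are never compared.
# `shapes_compatible` is a hypothetical helper name chosen for this example.
from typing import Sequence


def shapes_compatible(mixture_batch_shape: Sequence[int],
                      component_batch_shape: Sequence[int]) -> bool:
    # Drop the rightmost dimension, which indexes the `k` components.
    cdbs = tuple(component_batch_shape)[:-1]
    mdbs = tuple(mixture_batch_shape)
    for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
        if size1 != 1 and size2 != 1 and size1 != size2:
            return False
    return True


# Scalar mixture batch vs. a single 5-component mixture: compatible.
print(shapes_compatible((), (5,)))
# Batch of 3 mixtures vs. batch of 3 x 5 components: compatible.
print(shapes_compatible((3,), (3, 5)))
# Batch of 4 mixtures vs. batch of 3 x 5 components: incompatible.
print(shapes_compatible((4,), (3, 5)))
```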
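```python
# Illustrative sketch (not part of torch.distributions): the formulas behind
# the `mean` and `variance` properties, for a 1-D mixture whose components
# have known means and variances (for example, normal components).
#   mean     = sum_k w_k * mu_k
#   variance = E[Var(Y|X)] + Var(E[Y|X])          (law of total variance)
#            = sum_k w_k * var_k + sum_k w_k * (mu_k - mean)**2
# Weights are normalized first, mirroring how Categorical turns unnormalized
# probabilities into `probs`. `mixture_moments` is a hypothetical name.
import math
from typing import Sequence, Tuple


def mixture_moments(weights: Sequence[float],
                    means: Sequence[float],
                    variances: Sequence[float]) -> Tuple[float, float]:
    if not (len(weights) == len(means) == len(variances)):
        raise ValueError("weights, means and variances must have equal length")
    total = math.fsum(weights)
    probs = [w / total for w in weights]
    mean = math.fsum(p * m for p, m in zip(probs, means))
    # E[Var(Y|X)]: average of the component variances.
    mean_cond_var = math.fsum(p * v for p, v in zip(probs, variances))
    # Var(E[Y|X]): spread of the component means around the mixture mean.
    var_cond_mean = math.fsum(p * (m - mean) ** 2 for p, m in zip(probs, means))
    return mean, mean_cond_var + var_cond_mean


# Two equally weighted unit-variance components centred at -1 and +1:
# the mixture mean is 0 and the variance is 1 (within) + 1 (between) = 2.
print(mixture_moments([1, 1], [-1.0, 1.0], [1.0, 1.0]))
```

The same variance can be cross-checked as E[Y^2] - E[Y]^2, where E[Y^2] = sum_k w_k * (var_k + mu_k**2); both routes agree up to floating-point rounding.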