# Source code for torch.distributions.mixture_same_family

import torch
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from torch.distributions import constraints
from typing import Dict

class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
                     torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct a batch of 3 Gaussian Mixture Models in 2D each
        # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
                    torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`; shapes are compared
            right-aligned and size-1 dimensions broadcast.
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes components.
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(self,
                 mixture_distribution,
                 component_distribution,
                 validate_args=None):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(" The Mixture distribution needs to be an "
                             " instance of torch.distribtutions.Categorical")

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError("The Component distribution need to be an "
                             "instance of torch.distributions.Distribution")

        # Check that batch sizes are compatible. Shapes are compared
        # right-aligned; a size-1 dimension on either side is accepted.
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError("mixture_distribution.batch_shape ({0}) is not "
                                 "compatible with component_distribution."
                                 "batch_shape({1})".format(mdbs, cdbs))

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError("mixture_distribution component ({0}) does not"
                             " equal component_distribution.batch_shape[-1]"
                             " ({1})".format(km, kc))
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        # The distribution's batch shape is the component batch shape with
        # the trailing k (component index) dimension removed.
        super(MixtureSameFamily, self).__init__(batch_shape=cdbs,
                                                event_shape=event_shape,
                                                validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose batch dimensions are expanded to
        ``batch_shape``; the component distribution gets an extra trailing
        ``k`` dimension."""
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = \
            self._component_distribution.expand(batch_shape_comp)
        new._mixture_distribution = \
            self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(batch_shape=batch_shape,
                                               event_shape=event_shape,
                                               validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
        """Probability-weighted sum of the component means, shape [B, E]."""
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(probs * self.component_distribution.mean,
                         dim=-1 - self._event_ndims)  # [B, E]

    @property
    def variance(self):
        """Mixture variance computed via the law of total variance."""
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(probs * self.component_distribution.variance,
                                  dim=-1 - self._event_ndims)
        var_cond_mean = torch.sum(probs * (self.component_distribution.mean -
                                           self._pad(self.mean)).pow(2.0),
                                  dim=-1 - self._event_ndims)
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        """Mixture CDF: the component CDFs at ``x`` weighted by the mixture
        probabilities and summed over the component dimension."""
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs
        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        """Log-density of ``x`` computed as a logsumexp over components of
        component log-prob plus normalized mixture log-weight."""
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
                                         dim=-1)  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        A component index is drawn from the mixture for every sample and
        batch element, and the matching component sample is gathered.
        """
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # The mixture's batch shape may be smaller than (but broadcastable
            # to) the distribution's batch shape, e.g. a scalar Categorical.
            # Expand it so that every batch element draws its own component
            # index and the gather index below has the same rank as
            # comp_samples.
            mix = self.mixture_distribution
            if mix.batch_shape != self.batch_shape:
                mix = mix.expand(self.batch_shape)

            # mixture samples [n, B]
            mix_sample = mix.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1)))
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es)

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        # Insert a singleton k dimension just before the event dimensions so
        # that x broadcasts against the component distribution.
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        """Reshape mixture parameters ``x`` (shape ``mixture_batch + [k]``)
        to ``[1]*pad + mixture_batch + [k] + [1]*event_ndims``.

        The mixture batch dims are right-aligned with the distribution batch
        dims (as checked in ``__init__``), so any missing leading batch dims
        are padded on the left. Trailing singleton dims let the result
        broadcast over the event dimensions.
        """
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = max(dist_batch_ndims - cat_batch_ndims, 0)
        return x.reshape(torch.Size(pad_ndims * [1]) +
                         x.shape +
                         torch.Size(self._event_ndims * [1]))

    def __repr__(self):
        args_string = '\n  {},\n  {}'.format(self.mixture_distribution,
                                             self.component_distribution)
        return 'MixtureSameFamily' + '(' + args_string + ')'

