Shortcuts

# Source code for torch.distributions.mixture_same_family

# mypy: allow-untyped-defs
from typing import Dict

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution

# Public names exported by this module.
__all__ = ["MixtureSameFamily"]

class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical` instance.
            Manages the probability of selecting a component. The number of
            categories must match the rightmost batch dimension of the
            `component_distribution`. Its `batch_shape` must broadcast
            (right-aligned) against `component_distribution.batch_shape[:-1]`,
            e.g. be scalar or equal to it.
        component_distribution: `torch.distributions.Distribution` instance.
            Its right-most batch dimension indexes the components.
        validate_args: forwarded to :class:`Distribution`.

    Raises:
        ValueError: if either argument has the wrong type, if the component
            distribution has no batch dimension, if the batch shapes are not
            broadcast-compatible, or if the number of categories differs from
            the number of components.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                " The Mixture distribution needs to be an "
                " instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution need to be an "
                "instance of torch.distributions.Distribution"
            )

        # The rightmost component batch dim indexes the mixture components, so
        # it must exist (otherwise `batch_shape[-1]` below would IndexError).
        comp_batch_shape = self._component_distribution.batch_shape
        if len(comp_batch_shape) == 0:
            raise ValueError(
                "component_distribution must have at least one batch dimension "
                f"indexing the mixture components, got batch_shape {comp_batch_shape}"
            )

        # The mixture batch shape must broadcast (right-aligned) against the
        # component batch shape with its trailing component dimension removed.
        mdbs = self._mixture_distribution.batch_shape
        cdbs = comp_batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"mixture_distribution.batch_shape ({mdbs}) is not "
                    "compatible with component_distribution."
                    f"batch_shape({cdbs})"
                )

        # The number of categories must equal the number of components.
        km = self._mixture_distribution.logits.shape[-1]
        kc = comp_batch_shape[-1]
        if km != kc:
            raise ValueError(
                f"mixture_distribution component ({km}) does not"
                " equal component_distribution.batch_shape[-1]"
                f" ({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new mixture whose batch shape is ``batch_shape``.

        The component distribution is expanded to ``batch_shape + (k,)`` and
        the mixture distribution to ``batch_shape``.
        """
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        """The `Categorical` distribution over components."""
        return self._mixture_distribution

    @property
    def component_distribution(self):
        """The batched component distribution (last batch dim indexes components)."""
        return self._component_distribution

    @property
    def mean(self):
        """Mixture mean ``sum_k p_k * mean_k`` of shape ``batch_shape + event_shape``."""
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        """Mixture variance of shape ``batch_shape + event_shape``."""
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        """Mixture CDF ``sum_k p_k * F_k(x)``.

        Requires the component distribution to implement ``cdf``; the mixture
        weights are summed over the last axis of the component CDF values.
        """
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs
        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        """Mixture log density ``logsumexp_k(log p_k + log f_k(x))``."""
        if self._validate_args:
            self._validate_sample(x)
        # Insert the component axis so x broadcasts against all k components.
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Samples a component index per batch element from the mixture, samples
        every component, and gathers the selected one. Not reparameterized
        (``has_rsample`` is False), so this runs without autograd tracking.
        """
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # The mixture may have a smaller (broadcastable) batch shape, e.g.
            # scalar. Expand it so that every batch element draws its own
            # independent component index and the index tensor has as many
            # dims as `comp_samples`, as torch.gather requires.
            mix_dist = self.mixture_distribution
            if mix_dist.batch_shape != self.batch_shape:
                mix_dist = mix_dist.expand(self.batch_shape)

            # mixture samples [n, B]
            mix_sample = mix_dist.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension: index shape [n, B, 1, E]
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        """Insert a size-1 component axis just before the event dims of ``x``."""
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        """Append ``event_ndims`` singleton axes to mixture weights ``x``.

        ``x`` has shape ``mixture_batch_shape + (k,)``; the result has shape
        ``mixture_batch_shape + (k,) + (1,) * event_ndims``, so it broadcasts
        against component tensors of shape
        ``component_batch_shape + event_shape``. Leading batch dims need no
        explicit padding: right-aligned broadcasting handles them, which is
        the same compatibility rule ``__init__`` enforces.
        """
        return x.reshape(x.shape + torch.Size(self._event_ndims * [1]))

    def __repr__(self):
        args_string = (
            f"\n  {self.mixture_distribution},\n  {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"


## Docs

Access comprehensive developer documentation for PyTorch

View Docs

## Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials