Shortcuts

# Source code for torch.distributions.exp_family

import torch
from torch.distributions.distribution import Distribution

# Public names exported by a star-import of this module.
__all__ = ['ExponentialFamily']

class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient
    statistic, :math:`F(\theta)` is the log normalizer function for a given family and
    :math:`k(x)` is the carrier measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which
        belong to an exponential family mainly to check the correctness of the `.entropy()` and
        analytic KL divergence methods. We use this class to compute the entropy and KL
        divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen
        and Richard Nock, Entropies and Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters.

        Returns:
            A tuple of Tensors based on the distribution.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function.

        Args:
            *natural_params: the natural parameters, passed positionally in the
                same order as they appear in ``_natural_params``.

        Returns:
            A log normalizer based on the distribution and input.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for the expected carrier measure, which is required for
        computing entropy.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def entropy(self):
        r"""
        Method to compute the entropy using Bregman divergence of the log normalizer.

        The value returned is

        .. math::

            F(\theta) - \langle \theta, \nabla F(\theta) \rangle - E[k(x)]

        where :math:`\nabla F` is obtained with ``torch.autograd.grad``.

        Returns:
            The entropy, with one entry per element of ``self._batch_shape``.
        """
        # Differentiate w.r.t. fresh leaf copies of the natural parameters so
        # the gradient call below is independent of whatever graph produced
        # the original parameters.
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)

        # Accumulate out-of-place: ``_mean_carrier_measure`` may be a Python
        # number or a tensor that merely broadcasts against the batch shape
        # (e.g. 0-dim), in which case in-place ``+=`` / ``-=`` would raise.
        result = lg_normal - self._mean_carrier_measure
        for param, grad in zip(nparams, gradients):
            # Flatten any event dimensions into one trailing axis and sum it,
            # leaving the inner product <theta, grad F> per batch element.
            inner = (param * grad).reshape(self._batch_shape + (-1,)).sum(-1)
            result = result - inner
        return result

## Docs

Access comprehensive developer documentation for PyTorch

View Docs

## Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials