Shortcuts

Source code for torch.distributions.dirichlet

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily

__all__ = ['Dirichlet']

# This helper is exposed for testing.
def _Dirichlet_backward(x, concentration, grad_output):
    total = concentration.sum(-1, True).expand_as(concentration)
    grad = torch._dirichlet_grad(x, concentration, total)
    return grad * (grad_output - (x * grad_output).sum(-1, True))


class _Dirichlet(Function):
    @staticmethod
    def forward(ctx, concentration):
        x = torch._sample_dirichlet(concentration)
        ctx.save_for_backward(x, concentration)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, concentration = ctx.saved_tensors
        return _Dirichlet_backward(x, concentration, grad_output)


[docs]class Dirichlet(ExponentialFamily): r""" Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`. Example:: >>> # xdoctest: +IGNORE_WANT("non-deterinistic") >>> m = Dirichlet(torch.tensor([0.5, 0.5])) >>> m.sample() # Dirichlet distributed with concentration [0.5, 0.5] tensor([ 0.1046, 0.8954]) Args: concentration (Tensor): concentration parameter of the distribution (often referred to as alpha) """ arg_constraints = {'concentration': constraints.independent(constraints.positive, 1)} support = constraints.simplex has_rsample = True def __init__(self, concentration, validate_args=None): if concentration.dim() < 1: raise ValueError("`concentration` parameter must be at least one-dimensional.") self.concentration = concentration batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:] super().__init__(batch_shape, event_shape, validate_args=validate_args)
[docs] def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Dirichlet, _instance) batch_shape = torch.Size(batch_shape) new.concentration = self.concentration.expand(batch_shape + self.event_shape) super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False) new._validate_args = self._validate_args return new
[docs] def rsample(self, sample_shape=()): shape = self._extended_shape(sample_shape) concentration = self.concentration.expand(shape) return _Dirichlet.apply(concentration)
[docs] def log_prob(self, value): if self._validate_args: self._validate_sample(value) return ((torch.log(value) * (self.concentration - 1.0)).sum(-1) + torch.lgamma(self.concentration.sum(-1)) - torch.lgamma(self.concentration).sum(-1))
@property def mean(self): return self.concentration / self.concentration.sum(-1, True) @property def mode(self): concentrationm1 = (self.concentration - 1).clamp(min=0.) mode = concentrationm1 / concentrationm1.sum(-1, True) mask = (self.concentration < 1).all(axis=-1) mode[mask] = torch.nn.functional.one_hot(mode[mask].argmax(axis=-1), concentrationm1.shape[-1]).to(mode) return mode @property def variance(self): con0 = self.concentration.sum(-1, True) return self.concentration * (con0 - self.concentration) / (con0.pow(2) * (con0 + 1))
[docs] def entropy(self): k = self.concentration.size(-1) a0 = self.concentration.sum(-1) return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) - (k - a0) * torch.digamma(a0) - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1))
@property def _natural_params(self): return (self.concentration, ) def _log_normalizer(self, x): return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources