Source code for torch.distributions

r"""
The ``distributions`` package contains parameterizable probability distributions
and sampling functions.

Policy gradient methods can be implemented using the
:meth:`~torch.distributions.Distribution.log_prob` method, when the probability
density function is differentiable with respect to its parameters. A basic
method is the REINFORCE rule:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply
this action in an environment, and then use ``log_prob`` to construct an
equivalent loss function. Note the negative sign: optimizers perform gradient
descent, while the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    # `policy_network` and `env` stand in for a user-defined model and
    # environment that are assumed to exist.
    probs = policy_network(state)
    # NOTE: this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()
"""
import math
from numbers import Number
import torch


__all__ = ['Distribution', 'Bernoulli', 'Categorical', 'Normal']


class Distribution(object):
    r"""
    Distribution is the abstract base class for probability distributions.
    """

    def sample(self):
        """
        Generates a single sample or single batch of samples if the
        distribution parameters are batched.
        """
        raise NotImplementedError

    def sample_n(self, n):
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.
        """
        raise NotImplementedError

    def log_prob(self, value):
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor or Variable):
        """
        raise NotImplementedError


class Bernoulli(Distribution):
    r"""
    Creates a Bernoulli distribution parameterized by `probs`.

    Samples are binary (0 or 1). They take the value `1` with probability `p`
    and `0` with probability `1 - p`.

    Example::

        >>> m = Bernoulli(torch.Tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
         0.0
        [torch.FloatTensor of size 1]

    Args:
        probs (Tensor or Variable): the probability of sampling `1`
    """

    def __init__(self, probs):
        self.probs = probs

    def sample(self):
        return torch.bernoulli(self.probs)

    def sample_n(self, n):
        return torch.bernoulli(self.probs.expand(n, *self.probs.size()))

    def log_prob(self, value):
        # compute the log probabilities for 0 and 1
        log_pmf = (torch.stack([1 - self.probs, self.probs])).log()
        # evaluate using the values
        return log_pmf.gather(0, value.unsqueeze(0).long()).squeeze(0)
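
# Worked example of the stack/gather trick in Bernoulli.log_prob above: the
# stacked log_pmf has row 0 scoring value 0 and row 1 scoring value 1, so for
# probs = [0.3]:
#
#     >>> m = Bernoulli(torch.Tensor([0.3]))
#     >>> m.log_prob(torch.Tensor([1]))   # log(0.3) ~= -1.204
#     >>> m.log_prob(torch.Tensor([0]))   # log(0.7) ~= -0.357
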

class Categorical(Distribution):
    r"""
    Creates a categorical distribution parameterized by `probs`.

    .. note::
        It is equivalent to the distribution that ``multinomial()`` samples
        from.

    Samples are integers from `0 ... K-1` where `K` is ``probs.size(-1)``.

    If `probs` is 1D with length `K`, each element is the relative
    probability of sampling the class at that index.

    If `probs` is 2D, it is treated as a batch of probability vectors.

    See also: :func:`torch.multinomial`

    Example::

        >>> m = Categorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         3
        [torch.LongTensor of size 1]

    Args:
        probs (Tensor or Variable): event probabilities
    """

    def __init__(self, probs):
        if probs.dim() != 1 and probs.dim() != 2:
            # TODO: treat higher dimensions as part of the batch
            raise ValueError("probs must be 1D or 2D")
        self.probs = probs

    def sample(self):
        return torch.multinomial(self.probs, 1, True).squeeze(-1)

    def sample_n(self, n):
        if n == 1:
            return self.sample().expand(1, 1)
        else:
            return torch.multinomial(self.probs, n, True).t()

    def log_prob(self, value):
        # normalize the relative weights into proper probabilities
        p = self.probs / self.probs.sum(-1, keepdim=True)
        if value.dim() == 1 and self.probs.dim() == 1:
            # special handling until we have 0-dim tensor support
            return p.gather(-1, value).log()
        return p.gather(-1, value.unsqueeze(-1)).squeeze(-1).log()
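
# Batched usage sketch for Categorical: a 2D `probs` is treated as a batch of
# probability vectors, so `sample()` draws one index per row and `log_prob`
# scores one value per row. For example:
#
#     >>> m = Categorical(torch.Tensor([[0.9, 0.1], [0.1, 0.9]]))
#     >>> a = m.sample()   # LongTensor of size 2, one draw per row
#     >>> m.log_prob(a)    # FloatTensor of size 2, one log-probability per row
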

class Normal(Distribution):
    r"""
    Creates a normal (also called Gaussian) distribution parameterized by
    `mean` and `std`.

    Example::

        >>> m = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
        >>> m.sample()  # normally distributed with mean=0 and stddev=1
         0.1046
        [torch.FloatTensor of size 1]

    Args:
        mean (float or Tensor or Variable): mean of the distribution
        std (float or Tensor or Variable): standard deviation of the distribution
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def sample(self):
        return torch.normal(self.mean, self.std)

    def sample_n(self, n):
        # cleanly expand float or Tensor or Variable parameters
        def expand(v):
            if isinstance(v, Number):
                return torch.Tensor([v]).expand(n, 1)
            else:
                return v.expand(n, *v.size())
        return torch.normal(expand(self.mean), expand(self.std))

    def log_prob(self, value):
        # compute the variance
        var = (self.std ** 2)
        log_std = math.log(self.std) if isinstance(self.std, Number) else self.std.log()
        return -((value - self.mean) ** 2) / (2 * var) - log_std - math.log(math.sqrt(2 * math.pi))
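

if __name__ == "__main__":
    # A minimal usage sketch: draw one sample from each distribution and
    # score it with `log_prob`.
    bern = Bernoulli(torch.Tensor([0.3]))
    b = bern.sample()
    print(bern.log_prob(b))  # log(0.3) ~= -1.204 if b == 1, else log(0.7) ~= -0.357

    cat = Categorical(torch.Tensor([0.25, 0.25, 0.25, 0.25]))
    a = cat.sample()
    print(cat.log_prob(a))  # log(0.25) ~= -1.386 for every class of a uniform 4-way choice

    norm = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
    x = norm.sample()
    # agrees with the closed-form standard normal log-density:
    #     log N(x; 0, 1) = -x**2 / 2 - log(sqrt(2 * pi))
    print(norm.log_prob(x))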