Probability distributions - torch.distributions

The distributions package contains parameterizable probability distributions and sampling functions.

When the probability density function is differentiable with respect to its parameters, policy gradient methods can be implemented using the log_prob() method. A basic method is the REINFORCE rule:

\[\Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}\]

where \(\theta\) are the parameters, \(\alpha\) is the learning rate, \(r\) is the reward and \(p(a|\pi^\theta(s))\) is the probability of taking action \(a\) in state \(s\) given policy \(\pi^\theta\).

In practice we sample an action from the output of a network, apply that action in an environment, and then use log_prob to construct an equivalent loss function. Note the negative sign: optimizers perform gradient descent, while the rule above assumes gradient ascent. With a categorical policy, the code for implementing REINFORCE is as follows:

from torch.distributions import Categorical

probs = policy_network(state)
# NOTE: this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
# negated so that minimizing the loss ascends the REINFORCE objective
loss = -m.log_prob(action) * reward
loss.backward()
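
An alternative for continuous action spaces is the pathwise derivative, obtained by using rsample() (documented below) so that gradients flow through the sample itself. A minimal sketch, assuming the policy network outputs the parameters of a Normal distribution and that reward is differentiable with respect to the action:

params = policy_network(state)
m = Normal(*params)
# rsample() uses the reparameterization trick, so the sampled action
# remains differentiable with respect to the distribution parameters
action = m.rsample()
next_state, reward = env.step(action)
loss = -reward
loss.backward()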

Distribution

class torch.distributions.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]))[source]

Distribution is the abstract base class for probability distributions.

batch_shape

Returns the shape over which parameters are batched.

entropy()[source]

Returns entropy of distribution, batched over batch_shape.

Returns: Tensor or Variable of shape batch_shape.

enumerate_support()[source]

Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape + event_shape (where event_shape = () for univariate distributions).

Note that this enumerates over all batched variables in lock-step [[0, 0], [1, 1], …]. To iterate over the full Cartesian product use itertools.product(m.enumerate_support()).

Returns: Variable or Tensor iterating over dimension 0.
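
For instance, with a batched Categorical (a minimal sketch, not part of the original reference):

import torch
from torch.distributions import Categorical

# two categorical distributions over {0, 1}, batched: batch_shape = (2,)
m = Categorical(torch.Tensor([[0.1, 0.9],
                              [0.7, 0.3]]))
m.enumerate_support()
# [[0, 0],
#  [1, 1]]  -- shape (cardinality,) + batch_shape = (2, 2)
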
event_shape

Returns the shape of a single sample (without batching).
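
As an illustration (a minimal sketch using Normal, not part of the original reference), three univariate normals batched together have batch_shape (3,) and an empty event_shape, and sample_shape is prepended to both:

import torch
from torch.distributions import Normal

m = Normal(torch.zeros(3), torch.ones(3))
m.batch_shape                     # torch.Size([3]): three batched distributions
m.event_shape                     # torch.Size([]): each sample is a scalar
m.sample().size()                 # torch.Size([3])
m.sample(torch.Size([5])).size()  # torch.Size([5, 3]): sample_shape + batch_shape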

log_prob(value)[source]

Returns the log of the probability density/mass function evaluated at value.

Parameters: value (Tensor or Variable) – the value at which to evaluate the log probability

params

Returns a dictionary from param names to Constraint objects that should be satisfied by each parameter of this distribution. For distributions with multiple parameterizations, only one complete set of parameters should be specified in .params.

rsample(sample_shape=torch.Size([]))[source]

Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.

sample(sample_shape=torch.Size([]))[source]

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.

sample_n(n)[source]

Generates n samples or n batches of samples if the distribution parameters are batched.

support

Returns a Constraint object representing this distribution’s support.
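
For example (a sketch assuming Constraint objects expose a check() method for elementwise membership testing, as in current releases):

import torch
from torch.distributions import Uniform

u = Uniform(torch.Tensor([0.0]), torch.Tensor([5.0]))
u.support                             # an interval constraint on [low, high)
u.support.check(torch.Tensor([2.5]))  # truthy: inside the support
u.support.check(torch.Tensor([7.0]))  # falsy: outside the support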

Bernoulli

class torch.distributions.Bernoulli(probs=None, logits=None)[source]

Creates a Bernoulli distribution parameterized by probs or logits.

Samples are binary (0 or 1). They take the value 1 with probability p (given by probs) and 0 with probability 1 - p.

Example:

>>> m = Bernoulli(torch.Tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
 0.0
[torch.FloatTensor of size 1]
Parameters:
  • probs (Number, Tensor or Variable) – the probability of sampling 1
  • logits (Number, Tensor or Variable) – the log-odds of sampling 1

Beta

class torch.distributions.Beta(concentration1, concentration0)[source]

Beta distribution parameterized by concentration1 and concentration0.

Example:

>>> m = Beta(torch.Tensor([0.5]), torch.Tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
 0.1046
[torch.FloatTensor of size 1]
Parameters:
  • concentration1 (float or Tensor or Variable) – 1st concentration parameter of the distribution (often referred to as alpha)
  • concentration0 (float or Tensor or Variable) – 2nd concentration parameter of the distribution (often referred to as beta)

Binomial

class torch.distributions.Binomial(total_count=1, probs=None, logits=None)[source]

Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both).

  • Requires a single shared total_count for all parameters and samples.

Example:

>>> m = Binomial(100, torch.Tensor([0, 0.2, 0.8, 1]))
>>> m.sample()
 0
 22
 71
 100
[torch.FloatTensor of size 4]
Parameters:
  • total_count (int) – number of Bernoulli trials
  • probs (Tensor or Variable) – event probabilities
  • logits (Tensor or Variable) – event log-odds

Categorical

class torch.distributions.Categorical(probs=None, logits=None)[source]

Creates a categorical distribution parameterized by either probs or logits (but not both).

Note

It is equivalent to the distribution that torch.multinomial() samples from.

Samples are integers from 0 … K-1 where K is probs.size(-1).

If probs is 1D with length K, each element is the relative probability of sampling the class at that index.

If probs is 2D, it is treated as a batch of probability vectors (see the batched example after the parameter list below).

See also: torch.multinomial()

Example:

>>> m = Categorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
 3
[torch.LongTensor of size 1]
Parameters:
  • probs (Tensor or Variable) – event probabilities
  • logits (Tensor or Variable) – event log probabilities
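
To illustrate the batched (2D) case described above (a minimal sketch, not part of the original reference):

import torch
from torch.distributions import Categorical

# a batch of two probability vectors: batch_shape = (2,)
m = Categorical(torch.Tensor([[0.25, 0.25, 0.25, 0.25],
                              [0.10, 0.10, 0.10, 0.70]]))
m.sample()  # one class index per batch entry, shape (2,)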

Cauchy

class torch.distributions.Cauchy(loc, scale)[source]

Samples from a Cauchy (Lorentz) distribution. This is the distribution of the ratio of two independent normally distributed random variables with mean 0.

Example:

>>> m = Cauchy(torch.Tensor([0.0]), torch.Tensor([1.0]))
>>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
 2.3214
[torch.FloatTensor of size 1]
Parameters:
  • loc (float or Tensor or Variable) – mode or median of the distribution
  • scale (float or Tensor or Variable) – half width at half maximum

Chi2

class torch.distributions.Chi2(df)[source]

Creates a Chi2 distribution parameterized by shape parameter df. This is exactly equivalent to Gamma(concentration=0.5*df, rate=0.5).

Example:

>>> m = Chi2(torch.Tensor([1.0]))
>>> m.sample()  # Chi2 distributed with shape df=1
 0.1046
[torch.FloatTensor of size 1]
Parameters: df (float or Tensor or Variable) – shape parameter of the distribution

Dirichlet

class torch.distributions.Dirichlet(concentration)[source]

Creates a Dirichlet distribution parameterized by the concentration vector concentration.

Example:

>>> m = Dirichlet(torch.Tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with concentration concentration
 0.1046
 0.8954
[torch.FloatTensor of size 2]
Parameters: concentration (Tensor or Variable) – concentration parameter of the distribution (often referred to as alpha)

Exponential

class torch.distributions.Exponential(rate)[source]

Creates an Exponential distribution parameterized by rate.

Example:

>>> m = Exponential(torch.Tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
 0.1046
[torch.FloatTensor of size 1]
Parameters: rate (float or Tensor or Variable) – rate = 1 / scale of the distribution

FisherSnedecor

class torch.distributions.FisherSnedecor(df1, df2)[source]

Creates a Fisher-Snedecor distribution parameterized by df1 and df2.

Example:

>>> m = FisherSnedecor(torch.Tensor([1.0]), torch.Tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
 0.2453
[torch.FloatTensor of size 1]
Parameters:
  • df1 (float or Tensor or Variable) – degrees of freedom parameter 1
  • df2 (float or Tensor or Variable) – degrees of freedom parameter 2

Gamma

class torch.distributions.Gamma(concentration, rate)[source]

Creates a Gamma distribution parameterized by shape concentration and rate.

Example:

>>> m = Gamma(torch.Tensor([1.0]), torch.Tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
 0.1046
[torch.FloatTensor of size 1]
Parameters:
  • concentration (float or Tensor or Variable) – shape parameter of the distribution (often referred to as alpha)
  • rate (float or Tensor or Variable) – rate = 1 / scale of the distribution (often referred to as beta)

Geometric

class torch.distributions.Geometric(probs=None, logits=None)[source]

Creates a Geometric distribution parameterized by probs, where probs is the probability of success of each Bernoulli trial. A sample of k indicates that the first k trials failed before a success was seen on the (k + 1)-th trial.

Samples are non-negative integers [0, inf).

Example:

>>> m = Geometric(torch.Tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
 2
[torch.FloatTensor of size 1]
Parameters:
  • probs (Number, Tensor or Variable) – the probability of sampling 1. Must be in range (0, 1]
  • logits (Number, Tensor or Variable) – the log-odds of sampling 1.

Gumbel

class torch.distributions.Gumbel(loc, scale)[source]

Samples from a Gumbel Distribution.

Examples:

>>> m = Gumbel(torch.Tensor([1.0]), torch.Tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
 1.0124
[torch.FloatTensor of size 1]
Parameters:
  • loc (float or Tensor or Variable) – location parameter of the distribution
  • scale (float or Tensor or Variable) – scale parameter of the distribution

Laplace

class torch.distributions.Laplace(loc, scale)[source]

Creates a Laplace distribution parameterized by loc and scale.

Example:

>>> m = Laplace(torch.Tensor([0.0]), torch.Tensor([1.0]))
>>> m.sample()  # Laplace distributed with loc=0, scale=1
 0.1046
[torch.FloatTensor of size 1]
Parameters:
  • loc (float or Tensor or Variable) – mean of the distribution
  • scale (float or Tensor or Variable) – scale of the distribution

Normal

class torch.distributions.Normal(loc, scale)[source]

Creates a normal (also called Gaussian) distribution parameterized by loc and scale.

Example:

>>> m = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
 0.1046
[torch.FloatTensor of size 1]
Parameters:
  • loc (float or Tensor or Variable) – mean of the distribution (often referred to as mu)
  • scale (float or Tensor or Variable) – standard deviation of the distribution (often referred to as sigma)

OneHotCategorical

class torch.distributions.OneHotCategorical(probs=None, logits=None)[source]

Creates a one-hot categorical distribution parameterized by probs or logits.

Samples are one-hot coded vectors of size probs.size(-1).

See also: torch.distributions.Categorical()

Example:

>>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability over the 4 classes; returns a one-hot vector
 0
 0
 1
 0
[torch.FloatTensor of size 4]
Parameters: probs (Tensor or Variable) – event probabilities

Pareto

class torch.distributions.Pareto(scale, alpha)[source]

Samples from a Pareto Type 1 distribution.

Example:

>>> m = Pareto(torch.Tensor([1.0]), torch.Tensor([1.0]))
>>> m.sample()  # sample from a Pareto distribution with scale=1 and alpha=1
 1.5623
[torch.FloatTensor of size 1]
Parameters:
  • scale (float or Tensor or Variable) – scale parameter of the distribution
  • alpha (float or Tensor or Variable) – shape parameter of the distribution

StudentT

class torch.distributions.StudentT(df, loc=0.0, scale=1.0)[source]

Creates a Student’s t-distribution parameterized by degrees of freedom df, location loc and scale scale.

Example:

>>> m = StudentT(torch.Tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
 0.1046
[torch.FloatTensor of size 1]
Parameters: df (float or Tensor or Variable) – degrees of freedom

Uniform

class torch.distributions.Uniform(low, high)[source]

Generates uniformly distributed random samples from the half-open interval [low, high).

Example:

>>> m = Uniform(torch.Tensor([0.0]), torch.Tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
 2.3418
[torch.FloatTensor of size 1]
Parameters:
  • low (float or Tensor or Variable) – lower range (inclusive)
  • high (float or Tensor or Variable) – upper range (exclusive)

KL Divergence

torch.distributions.kl.kl_divergence(p, q)[source]

Compute Kullback-Leibler divergence \(KL(p \| q)\) between two distributions.

\[KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx\]
Parameters:
  • p (Distribution) – a Distribution object
  • q (Distribution) – a Distribution object

Returns: A batch of KL divergences of shape batch_shape.

Return type: Variable or Tensor

Raises: NotImplementedError – If the distribution types have not been registered via register_kl().
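
For instance (a minimal sketch; the Normal/Normal pair is one of the registered analytic cases):

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

p = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
q = Normal(torch.Tensor([1.0]), torch.Tensor([2.0]))
kl_divergence(p, q)  # tensor of shape batch_shape, here (1,)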

torch.distributions.kl.register_kl(type_p, type_q)[source]

Decorator to register a pairwise function with kl_divergence(). Usage:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here
    ...

Lookup returns the most specific (type, type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example, to resolve the ambiguous situation:

@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...
@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...

you should register a third most-specific implementation, e.g.:

register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.
Parameters:
  • type_p (type) – a subclass of Distribution
  • type_q (type) – a subclass of Distribution
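
A complete sketch of registration and dispatch (MyNormal is a hypothetical subclass introduced here purely for illustration):

import torch
from torch.distributions import Normal
from torch.distributions.kl import register_kl, kl_divergence

class MyNormal(Normal):  # hypothetical subclass, for illustration only
    pass

@register_kl(MyNormal, MyNormal)
def kl_mynormal_mynormal(p, q):
    # closed-form KL(p || q) between two univariate Gaussians
    var_ratio = (p.scale / q.scale) ** 2
    t1 = ((p.loc - q.loc) / q.scale) ** 2
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())

p = MyNormal(torch.Tensor([0.0]), torch.Tensor([1.0]))
q = MyNormal(torch.Tensor([1.0]), torch.Tensor([2.0]))
kl_divergence(p, q)  # dispatches to the most specific registered pair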