Probability distributions - torch.distributions¶
The distributions
package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization.
It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through. These are the score function estimator/likelihood ratio estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly seen as the basis for policy gradient methods in reinforcement learning, and the pathwise derivative estimator is commonly seen in the reparameterization trick in variational autoencoders. Whilst the score function only requires the value of samples f(x), the pathwise derivative requires the derivative f′(x). The next sections discuss these two in a reinforcement learning example. For more details see Gradient Estimation Using Stochastic Computation Graphs .
Score function¶
When the probability density function is differentiable with respect to its
parameters, we only need sample()
and
log_prob()
to implement REINFORCE:

Δθ = α r ∂log p(a|πθ(s)) / ∂θ

where θ are the parameters, α is the learning rate, r is the reward and p(a|πθ(s)) is the probability of taking action a in state s given policy πθ.
In practice we would sample an action from the output of a network, apply this
action in an environment, and then use log_prob
to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows:
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
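Putting the pieces together, a self-contained sketch of one full REINFORCE update might look as follows. The toy policy network, the input state and the reward below are placeholders invented for illustration; in practice they would come from your model and environment.

import torch
import torch.nn as nn
from torch.distributions import Categorical

# Toy two-layer policy over 2 actions; stand-ins for a real model and environment.
policy_network = nn.Sequential(nn.Linear(4, 16), nn.ReLU(),
                               nn.Linear(16, 2), nn.Softmax(dim=-1))
optimizer = torch.optim.Adam(policy_network.parameters(), lr=1e-2)

state = torch.randn(4)               # placeholder for an environment observation
probs = policy_network(state)
m = Categorical(probs)
action = m.sample()
reward = torch.randn(())             # placeholder for the reward returned by env.step(action)

optimizer.zero_grad()
loss = -m.log_prob(action) * reward  # negative log-probability weighted by reward
loss.backward()
optimizer.step()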
Pathwise derivative¶
The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
rsample()
method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows:
params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action) # Assuming that reward is differentiable
loss = -reward
loss.backward()
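As a small check of the distinction (a sketch using only Normal), a reparameterized sample stays connected to the distribution's parameters, while a plain sample does not:

import torch
from torch.distributions import Normal

mu = torch.zeros(1, requires_grad=True)
m = Normal(mu, torch.ones(1))

x = m.rsample()          # reparameterized: x = mu + scale * noise
x.sum().backward()
print(mu.grad)           # tensor([1.]) -- gradient flows back to mu

mu.grad = None
y = m.sample()           # not reparameterized
print(y.requires_grad)   # False -- no path back to mu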
Distribution¶
-
class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]
Bases: object
Distribution is the abstract base class for probability distributions.
-
arg_constraints
¶ Returns a dictionary from argument names to
Constraint
objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict.
-
batch_shape
¶ Returns the shape over which parameters are batched.
-
cdf
(value)[source]¶ Returns the cumulative density/mass function evaluated at value.
Parameters: value (Tensor) –
-
entropy
()[source]¶ Returns entropy of distribution, batched over batch_shape.
Returns: Tensor of shape batch_shape.
-
enumerate_support
()[source]¶ Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape + event_shape (where event_shape = () for univariate distributions).
Note that this enumerates over all batched tensors in lock-step [[0, 0], [1, 1], ...]. To iterate over the full Cartesian product use itertools.product(m.enumerate_support()).
Returns: Tensor iterating over dimension 0.
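For example (a small sketch), a batch of two Bernoulli distributions enumerates its support in lock-step along dimension 0:

import torch
from torch.distributions import Bernoulli

m = Bernoulli(torch.tensor([0.1, 0.9]))  # batch_shape = (2,)
m.enumerate_support()                    # tensor([[0., 0.], [1., 1.]]), shape (2, 2)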
-
event_shape
¶ Returns the shape of a single sample (without batching).
-
icdf
(value)[source]¶ Returns the inverse cumulative density/mass function evaluated at value.
Parameters: value (Tensor) –
-
log_prob
(value)[source]¶ Returns the log of the probability density/mass function evaluated at value.
Parameters: value (Tensor) –
-
mean
¶ Returns the mean of the distribution.
-
perplexity
()[source]¶ Returns perplexity of distribution, batched over batch_shape.
Returns: Tensor of shape batch_shape.
-
rsample
(sample_shape=torch.Size([]))[source]¶ Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.
-
sample
(sample_shape=torch.Size([]))[source]¶ Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
-
sample_n
(n)[source]¶ Generates n samples or n batches of samples if the distribution parameters are batched.
-
stddev
¶ Returns the standard deviation of the distribution.
-
support
¶ Returns a
Constraint
object representing this distribution’s support.
-
variance
¶ Returns the variance of the distribution.
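The following minimal sketch illustrates how sample_shape, batch_shape and event_shape combine for a univariate distribution:

import torch
from torch.distributions import Normal

d = Normal(torch.zeros(3), torch.ones(3))  # batch_shape = (3,), event_shape = ()
d.sample().shape                           # torch.Size([3])
d.sample((5,)).shape                       # torch.Size([5, 3]) = sample_shape + batch_shape + event_shape
d.log_prob(torch.zeros(3)).shape           # torch.Size([3]) -- one log-density per batch element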
-
ExponentialFamily¶
-
class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]
Bases: torch.distributions.distribution.Distribution
ExponentialFamily is the abstract base class for probability distributions belonging to an exponential family, whose probability mass/density function has the form

pF(x;θ) = exp(⟨t(x),θ⟩ − F(θ) + k(x))

where θ denotes the natural parameters, t(x) denotes the sufficient statistic, F(θ) is the log normalizer function for a given family and k(x) is the carrier measure.
Note
This class is an intermediary between the Distribution class and distributions which belong to an exponential family, mainly to check the correctness of the .entropy() and analytic KL divergence methods. We use this class to compute the entropy and KL divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and Cross-entropies of Exponential Families).
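As a quick sanity check (a sketch, not part of the documented API), the entropy returned for a member of the family agrees with the known closed form; for a Normal with scale σ it is 0.5·log(2πe·σ²):

import math
import torch
from torch.distributions import Normal

d = Normal(torch.tensor(0.0), torch.tensor(2.0))
d.entropy()                                      # tensor(2.1121)
0.5 * math.log(2 * math.pi * math.e * 2.0 ** 2)  # 2.1121...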
Bernoulli¶
-
class
torch.distributions.bernoulli.
Bernoulli
(probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.exp_family.ExponentialFamily
Creates a Bernoulli distribution parameterized by probs or logits.
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
Example:
>>> m = Bernoulli(torch.tensor([0.3])) >>> m.sample() # 30% chance 1; 70% chance 0 0.0 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'probs': <torch.distributions.constraints._Interval object>}¶
-
has_enumerate_support
= True¶
-
mean
¶
-
param_shape
¶
-
support
= <torch.distributions.constraints._Boolean object>¶
-
variance
¶
-
Beta¶
-
class
torch.distributions.beta.
Beta
(concentration1, concentration0, validate_args=None)[source]¶ Bases:
torch.distributions.exp_family.ExponentialFamily
Beta distribution parameterized by concentration1 and concentration0.
Example:
>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5])) >>> m.sample() # Beta distributed with concentration concentration1 and concentration0 0.1046 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'concentration1': <torch.distributions.constraints._GreaterThan object>, 'concentration0': <torch.distributions.constraints._GreaterThan object>}¶
-
concentration0
¶
-
concentration1
¶
-
has_rsample
= True¶
-
mean
¶
-
support
= <torch.distributions.constraints._Interval object>¶
-
variance
¶
-
Binomial¶
-
class
torch.distributions.binomial.
Binomial
(total_count=1, probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both).
- Requires a single shared total_count for all parameters and samples.
Example:
>>> m = Binomial(100, torch.tensor([0, .2, .8, 1])) >>> x = m.sample() 0 22 71 100 [torch.FloatTensor of size 4]
Parameters: -
arg_constraints
= {'probs': <torch.distributions.constraints._Interval object>}¶
-
has_enumerate_support
= True¶
-
mean
¶
-
param_shape
¶
-
support
¶
-
variance
¶
Categorical¶
-
class
torch.distributions.categorical.
Categorical
(probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a categorical distribution parameterized by either probs or logits (but not both).

Note
It is equivalent to the distribution that torch.multinomial() samples from.

Samples are integers from 0 ... K-1 where K is probs.size(-1).

If probs is 1D with length K, each element is the relative probability of sampling the class at that index.

If probs is 2D, it is treated as a batch of relative probability vectors.

Note
probs will be normalized to sum to 1.

See also: torch.multinomial()

Example:
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) >>> m.sample() # equal probability of 0, 1, 2, 3 3 [torch.LongTensor of size 1]
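A short sketch of the batched behaviour described above: with 2D probs each row is one distribution, and rows need not be normalized in advance.

import torch
from torch.distributions import Categorical

probs = torch.tensor([[1.0, 1.0, 2.0],
                      [0.1, 0.1, 0.8]])  # two vectors of relative probabilities
m = Categorical(probs)                   # batch_shape = (2,), K = 3
m.sample()                               # e.g. tensor([2, 2])
m.log_prob(torch.tensor([0, 2]))         # log-probability of class 0 in row 0, class 2 in row 1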
Parameters: -
arg_constraints
= {'probs': <torch.distributions.constraints._Simplex object>}¶
-
has_enumerate_support
= True¶
-
mean
¶
-
param_shape
¶
-
support
¶
-
variance
¶
-
Cauchy¶
-
class
torch.distributions.cauchy.
Cauchy
(loc, scale, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means 0 follows a Cauchy distribution.
Example:
>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1 2.3214 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'loc': <torch.distributions.constraints._Real object>, 'scale': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
support
= <torch.distributions.constraints._Real object>¶
-
variance
¶
-
Chi2¶
-
class
torch.distributions.chi2.
Chi2
(df, validate_args=None)[source]¶ Bases:
torch.distributions.gamma.Gamma
Creates a Chi2 distribution parameterized by shape parameter df. This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5)
Example:
>>> m = Chi2(torch.tensor([1.0])) >>> m.sample() # Chi2 distributed with shape df=1 0.1046 [torch.FloatTensor of size 1]
Parameters: df (float or Tensor) – shape parameter of the distribution -
arg_constraints
= {'df': <torch.distributions.constraints._GreaterThan object>}¶
-
df
¶
-
Dirichlet¶
-
class
torch.distributions.dirichlet.
Dirichlet
(concentration, validate_args=None)[source]¶ Bases:
torch.distributions.exp_family.ExponentialFamily
Creates a Dirichlet distribution parameterized by concentration concentration.
Example:
>>> m = Dirichlet(torch.tensor([0.5, 0.5])) >>> m.sample() # Dirichlet distributed with the given concentration 0.1046 0.8954 [torch.FloatTensor of size 2]
Parameters: concentration (Tensor) – concentration parameter of the distribution (often referred to as alpha) -
arg_constraints
= {'concentration': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
support
= <torch.distributions.constraints._Simplex object>¶
-
variance
¶
-
Exponential¶
-
class
torch.distributions.exponential.
Exponential
(rate, validate_args=None)[source]¶ Bases:
torch.distributions.exp_family.ExponentialFamily
Creates an Exponential distribution parameterized by rate.
Example:
>>> m = Exponential(torch.tensor([1.0])) >>> m.sample() # Exponential distributed with rate=1 0.1046 [torch.FloatTensor of size 1]
Parameters: rate (float or Tensor) – rate = 1 / scale of the distribution -
arg_constraints
= {'rate': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
stddev
¶
-
support
= <torch.distributions.constraints._GreaterThan object>¶
-
variance
¶
-
FisherSnedecor¶
-
class
torch.distributions.fishersnedecor.
FisherSnedecor
(df1, df2, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a Fisher-Snedecor distribution parameterized by df1 and df2.
Example:
>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0])) >>> m.sample() # Fisher-Snedecor-distributed with df1=1 and df2=2 0.2453 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'df1': <torch.distributions.constraints._GreaterThan object>, 'df2': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
support
= <torch.distributions.constraints._GreaterThan object>¶
-
variance
¶
-
Gamma¶
-
class
torch.distributions.gamma.
Gamma
(concentration, rate, validate_args=None)[source]¶ Bases:
torch.distributions.exp_family.ExponentialFamily
Creates a Gamma distribution parameterized by shape concentration and rate.
Example:
>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0])) >>> m.sample() # Gamma distributed with concentration=1 and rate=1 0.1046 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'concentration': <torch.distributions.constraints._GreaterThan object>, 'rate': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
support
= <torch.distributions.constraints._GreaterThan object>¶
-
variance
¶
-
Geometric¶
-
class
torch.distributions.geometric.
Geometric
(probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a Geometric distribution parameterized by probs, where probs is the probability of success of each Bernoulli trial. It represents the probability that in k + 1 Bernoulli trials, the first k trials fail before a success is seen.
Samples are non-negative integers [0, inf).
Example:
>>> m = Geometric(torch.tensor([0.3])) >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0 2 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'probs': <torch.distributions.constraints._Interval object>}¶
-
mean
¶
-
support
= <torch.distributions.constraints._IntegerGreaterThan object>¶
-
variance
¶
-
Gumbel¶
-
class
torch.distributions.gumbel.
Gumbel
(loc, scale, validate_args=None)[source]¶ Bases:
torch.distributions.transformed_distribution.TransformedDistribution
Samples from a Gumbel Distribution.
Examples:
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0])) >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2 1.0124 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'loc': <torch.distributions.constraints._Real object>, 'scale': <torch.distributions.constraints._GreaterThan object>}¶
-
mean
¶
-
stddev
¶
-
support
= <torch.distributions.constraints._Real object>¶
-
variance
¶
-
Independent¶
-
class
torch.distributions.independent.
Independent
(base_distribution, reinterpreted_batch_ndims, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Reinterprets some of the batch dims of a distribution as event dims.
This is mainly useful for changing the shape of the result of log_prob(). For example, to create a diagonal Normal distribution with the same shape as a Multivariate Normal distribution (so they are interchangeable), you can:

>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size(()), torch.Size((3,))]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size((3,)), torch.Size(())]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size(()), torch.Size((3,))]
Parameters: - base_distribution (torch.distributions.distribution.Distribution) – a base distribution
- reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims
-
arg_constraints
= {}¶
-
has_enumerate_support
¶
-
has_rsample
¶
-
mean
¶
-
support
¶
-
variance
¶
Laplace¶
-
class
torch.distributions.laplace.
Laplace
(loc, scale, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a Laplace distribution parameterized by loc and scale.
Example:
>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # Laplace distributed with loc=0, scale=1 0.1046 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'loc': <torch.distributions.constraints._Real object>, 'scale': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
stddev
¶
-
support
= <torch.distributions.constraints._Real object>¶
-
variance
¶
-
LogNormal¶
-
class
torch.distributions.log_normal.
LogNormal
(loc, scale, validate_args=None)[source]¶ Bases:
torch.distributions.transformed_distribution.TransformedDistribution
Creates a log-normal distribution parameterized by loc and scale where:
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)
Example:
>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # log-normal distributed with loc=0 and scale=1 0.1046 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'loc': <torch.distributions.constraints._Real object>, 'scale': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
loc
¶
-
mean
¶
-
scale
¶
-
support
= <torch.distributions.constraints._GreaterThan object>¶
-
variance
¶
-
Multinomial¶
-
class
torch.distributions.multinomial.
Multinomial
(total_count=1, probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a Multinomial distribution parameterized by total_count and either probs or logits (but not both). The innermost dimension of probs indexes over categories. All other dimensions index over batches.
Note that total_count need not be specified if only log_prob() is called (see example below).

Note
probs will be normalized to sum to 1.

sample() requires a single shared total_count for all parameters and samples.

log_prob() allows different total_count for each parameter and sample.
Example:
>>> m = Multinomial(100, torch.tensor([1., 1., 1., 1.])) >>> x = m.sample() # equal probability of 0, 1, 2, 3 21 24 30 25 [torch.FloatTensor of size 4] >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x) -4.1338 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'logits': <torch.distributions.constraints._Real object>}¶
-
logits
¶
-
mean
¶
-
param_shape
¶
-
probs
¶
-
support
¶
-
variance
¶
MultivariateNormal¶
-
class
torch.distributions.multivariate_normal.
MultivariateNormal
(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a multivariate normal (also called Gaussian) distribution parameterized by a mean vector and a covariance matrix.
The multivariate normal distribution can be parameterized either in terms of a positive definite covariance matrix Σ, or a positive definite precision matrix Σ⁻¹, or a lower-triangular matrix L with positive-valued diagonal entries such that Σ=LL⊤. This triangular matrix can be obtained via e.g. Cholesky decomposition of the covariance.
Example
>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2)) >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I` -0.2102 -0.5429 [torch.FloatTensor of size 2]
Parameters:

Note
Only one of covariance_matrix or precision_matrix or scale_tril can be specified.

Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition.
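For instance (a sketch of the note above), the two constructions below describe the same distribution, but passing scale_tril skips the internal factorization:

import torch
from torch.distributions import MultivariateNormal

L = torch.tensor([[1.0, 0.0],
                  [0.3, 0.9]])  # any lower-triangular matrix with positive diagonal
cov = L @ L.t()                 # the covariance matrix it induces

m1 = MultivariateNormal(torch.zeros(2), covariance_matrix=cov)  # factorizes cov internally
m2 = MultivariateNormal(torch.zeros(2), scale_tril=L)           # uses L directly
-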
arg_constraints
= {'loc': <torch.distributions.constraints._RealVector object>, 'covariance_matrix': <torch.distributions.constraints._PositiveDefinite object>, 'precision_matrix': <torch.distributions.constraints._PositiveDefinite object>, 'scale_tril': <torch.distributions.constraints._LowerCholesky object>}¶
-
has_rsample
= True¶
-
mean
¶
-
support
= <torch.distributions.constraints._Real object>¶
-
variance
¶
-
Normal¶
-
class
torch.distributions.normal.
Normal
(loc, scale, validate_args=None)[source]¶ Bases:
torch.distributions.exp_family.ExponentialFamily
Creates a normal (also called Gaussian) distribution parameterized by loc and scale.
Example:
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # normally distributed with loc=0 and scale=1 0.1046 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'loc': <torch.distributions.constraints._Real object>, 'scale': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
stddev
¶
-
support
= <torch.distributions.constraints._Real object>¶
-
variance
¶
-
OneHotCategorical¶
-
class
torch.distributions.one_hot_categorical.
OneHotCategorical
(probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a one-hot categorical distribution parameterized by probs or logits.

Samples are one-hot coded vectors of size probs.size(-1).

Note
probs will be normalized to sum to 1.

See also: torch.distributions.Categorical() for specifications of probs and logits.

Example:
>>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) >>> m.sample() # equal probability of 0, 1, 2, 3 0 0 1 0 [torch.FloatTensor of size 4]
Parameters: -
arg_constraints
= {'probs': <torch.distributions.constraints._Simplex object>}¶
-
has_enumerate_support
= True¶
-
logits
¶
-
mean
¶
-
param_shape
¶
-
probs
¶
-
support
= <torch.distributions.constraints._Simplex object>¶
-
variance
¶
-
Pareto¶
-
class
torch.distributions.pareto.
Pareto
(scale, alpha, validate_args=None)[source]¶ Bases:
torch.distributions.transformed_distribution.TransformedDistribution
Samples from a Pareto Type 1 distribution.
Example:
>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0])) >>> m.sample() # sample from a Pareto distribution with scale=1 and alpha=1 1.5623 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'alpha': <torch.distributions.constraints._GreaterThan object>, 'scale': <torch.distributions.constraints._GreaterThan object>}¶
-
mean
¶
-
support
¶
-
variance
¶
-
Poisson¶
-
class
torch.distributions.poisson.
Poisson
(rate, validate_args=None)[source]¶ Bases:
torch.distributions.exp_family.ExponentialFamily
Creates a Poisson distribution parameterized by rate, the rate parameter.
Samples are nonnegative integers, with a pmf given by rate^k e^{-rate} / k!.
Example:
>>> m = Poisson(torch.tensor([4])) >>> m.sample() 3 [torch.LongTensor of size 1]
Parameters: rate (Number, Tensor) – the rate parameter -
arg_constraints
= {'rate': <torch.distributions.constraints._GreaterThan object>}¶
-
mean
¶
-
support
= <torch.distributions.constraints._IntegerGreaterThan object>¶
-
variance
¶
-
RelaxedBernoulli¶
-
class
torch.distributions.relaxed_bernoulli.
RelaxedBernoulli
(temperature, probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.transformed_distribution.TransformedDistribution
Creates a RelaxedBernoulli distribution, parametrized by temperature, and either probs or logits. This is a relaxed version of the Bernoulli distribution, so its values are in (0, 1) and it has reparametrizable samples.
Example:
>>> m = RelaxedBernoulli(torch.tensor([2.2]), torch.tensor([0.1, 0.2, 0.3, 0.99])) >>> m.sample() 0.2951 0.3442 0.8918 0.9021 [torch.FloatTensor of size 4]
Parameters: -
arg_constraints
= {'probs': <torch.distributions.constraints._Interval object>}¶
-
has_rsample
= True¶
-
logits
¶
-
probs
¶
-
support
= <torch.distributions.constraints._Interval object>¶
-
temperature
¶
-
RelaxedOneHotCategorical¶
-
class
torch.distributions.relaxed_categorical.
RelaxedOneHotCategorical
(temperature, probs=None, logits=None, validate_args=None)[source]¶ Bases:
torch.distributions.transformed_distribution.TransformedDistribution
Creates a RelaxedOneHotCategorical distribution parametrized by temperature and either probs or logits. This is a relaxed version of the OneHotCategorical distribution, so its values lie on the simplex and it has reparametrizable samples.
Example:
>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), torch.tensor([0.1, 0.2, 0.3, 0.4])) >>> m.sample() # relaxed sample lies on the simplex 0.1294 0.2324 0.3859 0.2523 [torch.FloatTensor of size 4]
Parameters: -
arg_constraints
= {'probs': <torch.distributions.constraints._Simplex object>}¶
-
has_rsample
= True¶
-
logits
¶
-
probs
¶
-
support
= <torch.distributions.constraints._Simplex object>¶
-
temperature
¶
-
StudentT¶
-
class
torch.distributions.studentT.
StudentT
(df, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Creates a Student’s t-distribution parameterized by degrees of freedom df, location loc and scale scale.
Example:
>>> m = StudentT(torch.tensor([2.0])) >>> m.sample() # Student's t-distributed with degrees of freedom=2 0.1046 [torch.FloatTensor of size 1]
Parameters: df (float or Tensor) – degrees of freedom -
arg_constraints
= {'df': <torch.distributions.constraints._GreaterThan object>, 'loc': <torch.distributions.constraints._Real object>, 'scale': <torch.distributions.constraints._GreaterThan object>}¶
-
has_rsample
= True¶
-
mean
¶
-
support
= <torch.distributions.constraints._Real object>¶
-
variance
¶
-
TransformedDistribution¶
-
class
torch.distributions.transformed_distribution.
TransformedDistribution
(base_distribution, transforms, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.
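For example (a sketch), a log-normal distribution can be built by pushing a Normal base distribution through ExpTransform; this is essentially how the LogNormal distribution above is constructed:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

base = Normal(torch.zeros(1), torch.ones(1))
log_normal = TransformedDistribution(base, [ExpTransform()])
x = log_normal.rsample()   # strictly positive, differentiable sample
log_normal.log_prob(x)     # log p(x) = base.log_prob(log x) - log x
-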
arg_constraints
= {}¶
-
cdf
(value)[source]¶ Computes the cumulative distribution function by inverting the transform(s) and computing the cdf of the base distribution.
-
has_rsample
¶
-
icdf
(value)[source]¶ Computes the inverse cumulative distribution function using the transform(s) and the inverse cdf of the base distribution.
-
log_prob
(value)[source]¶ Scores the sample by inverting the transform(s) and computing the score using the score of the base distribution and the log abs det jacobian.
-
rsample
(sample_shape=torch.Size([]))[source]¶ Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.
-
sample
(sample_shape=torch.Size([]))[source]¶ Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.
-
support
¶
-
Uniform¶
-
class
torch.distributions.uniform.
Uniform
(low, high, validate_args=None)[source]¶ Bases:
torch.distributions.distribution.Distribution
Generates uniformly distributed random samples from the half-open interval [low, high).
Example:
>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0])) >>> m.sample() # uniformly distributed in the range [0.0, 5.0) 2.3418 [torch.FloatTensor of size 1]
Parameters: -
arg_constraints
= {'low': <torch.distributions.constraints._Dependent object>, 'high': <torch.distributions.constraints._Dependent object>}¶
-
has_rsample
= True¶
-
mean
¶
-
stddev
¶
-
support
¶
-
variance
¶
-
KL Divergence¶
-
torch.distributions.kl.kl_divergence(p, q)[source]
Compute Kullback-Leibler divergence KL(p ‖ q) between two distributions.

KL(p \| q) = \int p(x) \log\frac{p(x)}{q(x)} \, dx

Parameters:
- p (Distribution) – A Distribution object.
- q (Distribution) – A Distribution object.

Returns: A batch of KL divergences of shape batch_shape.
Return type: Tensor
Raises: NotImplementedError – If the distribution types have not been registered via register_kl().
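Usage sketch: the KL divergence between two univariate Normal distributions, returned with one value per batch element.

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

p = Normal(torch.zeros(3), torch.ones(3))
q = Normal(torch.ones(3), torch.ones(3))
kl_divergence(p, q)   # tensor([0.5000, 0.5000, 0.5000]), shape = batch_shape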
-
torch.distributions.kl.register_kl(type_p, type_q)[source]
Decorator to register a pairwise function with kl_divergence(). Usage:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here
Lookup returns the most specific (type,type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example to resolve the ambiguous situation:
@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...

@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...
you should register a third most-specific implementation, e.g.:
register_kl(DerivedP, DerivedQ)(kl_version1) # Break the tie.
Parameters:
Transforms¶
-
class torch.distributions.transforms.Transform(cache_size=0)[source]
Abstract class for invertible transformations with computable log det jacobians. They are primarily used in torch.distributions.TransformedDistribution.

Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values since the autograd graph may be reversed. For example, while the following works with or without caching:
y = t(x)
t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.
However the following will error when caching due to dependency reversal:
y = t(x)
z = t.inv(y)
grad(z.sum(), [y])  # error because z is x
Derived classes should implement one or both of _call() or _inverse(). Derived classes that set bijective=True should also implement log_abs_det_jacobian().

Parameters: cache_size (int) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.
Variables:
- domain (Constraint) – The constraint representing valid inputs to this transform.
- codomain (Constraint) – The constraint representing valid outputs to this transform which are inputs to the inverse transform.
- bijective (bool) – Whether this transform is bijective. A transform t is bijective iff t.inv(t(x)) == x and t(t.inv(y)) == y for every x in the domain and y in the codomain. Transforms that are not bijective should at least maintain the weaker pseudoinverse properties t(t.inv(t(x))) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
- sign (int or Tensor) – For bijective univariate transforms, this should be +1 or -1 depending on whether the transform is monotone increasing or decreasing.
- event_dim (int) – Number of dimensions that are correlated together in the transform event_shape. This should be 0 for pointwise transforms, 1 for transforms that act jointly on vectors, 2 for transforms that act jointly on matrices, etc.
-
sign
¶ Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.
-
class
torch.distributions.transforms.
ComposeTransform
(parts)[source]¶ Composes multiple transforms in a chain. The transforms being composed are responsible for caching.
Parameters: parts (list of Transform) – A list of transforms to compose.
-
class
torch.distributions.transforms.
ExpTransform
(cache_size=0)[source]¶ Transform via the mapping y = \exp(x).
-
class
torch.distributions.transforms.
PowerTransform
(exponent, cache_size=0)[source]¶ Transform via the mapping y = x^{\text{exponent}}.
-
class
torch.distributions.transforms.
SigmoidTransform
(cache_size=0)[source]¶ Transform via the mapping y = \frac{1}{1 + \exp(-x)} and x = \text{logit}(y).
-
class
torch.distributions.transforms.
AbsTransform
(cache_size=0)[source]¶ Transform via the mapping y = |x|.
-
class
torch.distributions.transforms.
AffineTransform
(loc, scale, event_dim=0, cache_size=0)[source]¶ Transform via the pointwise affine mapping y = \text{loc} + \text{scale} \times x.
Parameters:
-
class
torch.distributions.transforms.
SoftmaxTransform
(cache_size=0)[source]¶ Transform from unconstrained space to the simplex via y = \exp(x) then normalizing.
This is not bijective and cannot be used for HMC. However this acts mostly coordinate-wise (except for the final normalization), and thus is appropriate for coordinate-wise optimization algorithms.
-
class
torch.distributions.transforms.
StickBreakingTransform
(cache_size=0)[source]¶ Transform from unconstrained space to the simplex of one additional dimension via a stick-breaking process.
This transform arises as an iterated sigmoid transform in a stick-breaking construction of the Dirichlet distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.
This is bijective and appropriate for use in HMC; however it mixes coordinates together and is less appropriate for optimization.
Constraints¶
The following constraints are implemented:
constraints.boolean
constraints.dependent
constraints.greater_than(lower_bound)
constraints.integer_interval(lower_bound, upper_bound)
constraints.interval(lower_bound, upper_bound)
constraints.lower_cholesky
constraints.lower_triangular
constraints.nonnegative_integer
constraints.positive
constraints.positive_definite
constraints.positive_integer
constraints.real
constraints.real_vector
constraints.simplex
constraints.unit_interval
-
class
torch.distributions.constraints.
Constraint
[source]¶ Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
-
torch.distributions.constraints.
dependent_property
¶ alias of
_DependentProperty
-
torch.distributions.constraints.
integer_interval
¶ alias of
_IntegerInterval
-
torch.distributions.constraints.
greater_than
¶ alias of
_GreaterThan
-
torch.distributions.constraints.
less_than
¶ alias of
_LessThan
-
torch.distributions.constraints.
interval
¶ alias of
_Interval
Constraint Registry¶
PyTorch provides two global ConstraintRegistry
objects that link
Constraint
objects to
Transform
objects. These objects both
input constraints and return transforms, but they have different guarantees on
bijectivity.
biject_to(constraint) looks up a bijective Transform from constraints.real to the given constraint. The returned transform is guaranteed to have .bijective = True and should implement .log_abs_det_jacobian().

transform_to(constraint) looks up a not-necessarily bijective Transform from constraints.real to the given constraint. The returned transform is not guaranteed to implement .log_abs_det_jacobian().
The transform_to()
registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution’s .arg_constraints
dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam:
loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()
The biject_to()
registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained .support
are
propagated in an unconstrained space, and algorithms are typically rotation
invariant:
dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()
Note
An example where transform_to
and biject_to
differ is
constraints.simplex
: transform_to(constraints.simplex)
returns a
SoftmaxTransform
that simply
exponentiates and normalizes its inputs; this is a cheap and mostly
coordinate-wise operation appropriate for algorithms like SVI. In
contrast, biject_to(constraints.simplex)
returns a
StickBreakingTransform
that
bijects its input down to a one-fewer-dimensional space; this is a more
expensive and less numerically stable transform, but it is needed for algorithms
like HMC.
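A shape-level sketch of the difference on constraints.simplex: the softmax path preserves the input dimension, while the stick-breaking path maps to a simplex with one additional coordinate.

import torch
from torch.distributions import constraints
from torch.distributions.constraint_registry import biject_to, transform_to

x = torch.randn(4)
transform_to(constraints.simplex)(x).shape  # torch.Size([4]) -- SoftmaxTransform, same size
biject_to(constraints.simplex)(x).shape     # torch.Size([5]) -- StickBreakingTransform adds one coordinate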
The biject_to
and transform_to
objects can be extended by user-defined
constraints and transforms using their .register()
method either as a
function on singleton constraints:
transform_to.register(my_constraint, my_transform)
or as a decorator on parameterized constraints:
@transform_to.register(MyConstraintClass)
def my_factory(constraint):
assert isinstance(constraint, MyConstraintClass)
return MyTransform(constraint.param1, constraint.param2)
You can create your own registry by creating a new ConstraintRegistry
object.
-
class
torch.distributions.constraint_registry.
ConstraintRegistry
[source]¶ Registry to link constraints to transforms.
-
register
(constraint, factory=None)[source]¶ Registers a Constraint subclass in this registry. Usage:

@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
    assert isinstance(constraint, MyConstraint)
    return MyTransform(constraint.arg_constraints)
Parameters:
- constraint (subclass of Constraint) – A subclass of Constraint, or a singleton object of the desired class.
- factory (callable) – A callable that inputs a constraint object and returns a Transform object.
-