Probability distributions - torch.distributions¶
The distributions
package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the TensorFlow Distributions package.
It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through: the score function estimator (also called the likelihood ratio estimator or REINFORCE) and the pathwise derivative estimator. REINFORCE is commonly seen as the basis for policy gradient methods in reinforcement learning, and the pathwise derivative estimator is commonly seen in the reparameterization trick in variational autoencoders. Whilst the score function only requires the value of samples $f(x)$, the pathwise derivative requires the derivative $f'(x)$. The next sections discuss these two in a reinforcement learning example. For more details see Gradient Estimation Using Stochastic Computation Graphs.
Score function¶
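When the probability density function is differentiable with respect to its parameters, we only need sample() and log_prob() to implement REINFORCE:
$$\Delta\theta = \alpha r \frac{\partial \log p(a \mid \pi^\theta(s))}{\partial \theta}$$
where $\theta$ are the parameters, $\alpha$ is the learning rate, $r$ is the reward and $p(a \mid \pi^\theta(s))$ is the probability of taking action $a$ in state $s$ given policy $\pi^\theta$.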
In practice we would sample an action from the output of a network, apply this action in an environment, and then use log_prob to construct an equivalent loss function. Note that we use a negative because optimizers use gradient descent, whilst the rule above assumes gradient ascent. With a categorical policy, the code for implementing REINFORCE would be as follows:
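from torch.distributions import Categorical

# policy_network, state and env are assumed to come from the surrounding RL setup
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
# Negated because optimizers minimize, whilst REINFORCE ascends the reward
loss = -m.log_prob(action) * reward
loss.backward()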
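The snippet above leaves policy_network and env undefined. A minimal runnable sketch, assuming a hypothetical one-layer softmax policy and a toy env_step function (neither is part of torch.distributions), might look like this:

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical stand-in for a policy network: 4 state features -> 3 action probabilities.
policy_network = torch.nn.Sequential(
    torch.nn.Linear(4, 3),
    torch.nn.Softmax(dim=-1),
)

def env_step(action):
    # Hypothetical toy environment: action 2 earns a reward of 1, anything else 0.
    next_state = torch.zeros(4)
    reward = 1.0 if action.item() == 2 else 0.0
    return next_state, reward

state = torch.ones(4)
probs = policy_network(state)
m = Categorical(probs)
action = m.sample()                  # not differentiable
next_state, reward = env_step(action)

# Surrogate loss whose gradient is the score function estimate.
loss = -m.log_prob(action) * reward
loss.backward()

grad = policy_network[0].weight.grad  # gradient w.r.t. the policy weights
```

The sample itself carries no gradient; the gradient reaches the policy weights only through log_prob.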
Pathwise derivative¶
The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable. The code for implementing the pathwise derivative would be as follows:
from torch.distributions import Normal

# policy_network, state and env are assumed to come from the surrounding RL setup
params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()
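As a self-contained illustration of the pathwise derivative, the sketch below replaces the environment with a hypothetical differentiable reward, -(action - 2)^2, and uses leaf tensors loc and scale in place of a policy network. Both substitutions are assumptions made only to keep the example runnable.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Leaf parameters standing in for the output of a policy network.
loc = torch.tensor([0.5], requires_grad=True)
scale = torch.tensor([1.0], requires_grad=True)

m = Normal(loc, scale)
action = m.rsample()        # reparameterized: loc + scale * eps with eps ~ N(0, 1)

# Hypothetical differentiable reward, largest when action == 2.
reward = -(action - 2.0).pow(2).sum()
loss = -reward
loss.backward()

# Because d(action)/d(loc) == 1, the gradient w.r.t. loc is 2 * (action - 2).
expected_loc_grad = 2 * (action.detach() - 2.0)
```

A draw from sample() would be detached from loc and scale, so the same loss.backward() call would fail there; rsample() keeps the graph intact.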
Distribution¶
- class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]¶
Bases:
object
Distribution is the abstract base class for probability distributions.
- property arg_constraints: Dict[str, Constraint]¶
Returns a dictionary from argument names to Constraint objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict.
- entropy()[source]¶
Returns entropy of distribution, batched over batch_shape.
- Returns
Tensor of shape batch_shape.
- Return type
Tensor
- enumerate_support(expand=True)[source]¶
Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape + event_shape (where event_shape = () for univariate distributions).
Note that this enumerates over all batched tensors in lock-step [[0, 0], [1, 1], …]. With expand=False, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, [[0], [1], …].
To iterate over the full Cartesian product use itertools.product(m.enumerate_support()).
- expand(batch_shape, _instance=None)[source]¶
Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape. This method calls expand on the distribution’s parameters. As such, this does not allocate new memory for the expanded distribution instance. Additionally, this does not repeat any args checking or parameter broadcasting in __init__.py, when an instance is first created.
- Parameters
batch_shape (torch.Size) – the desired expanded size.
_instance – new instance provided by subclasses that need to override .expand.
- Returns
New distribution instance with batch dimensions expanded to batch_shape.
- log_prob(value)[source]¶
Returns the log of the probability density/mass function evaluated at value.
- perplexity()[source]¶
Returns perplexity of distribution, batched over batch_shape.
- Returns
Tensor of shape batch_shape.
- Return type
Tensor
- rsample(sample_shape=torch.Size([]))[source]¶
Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.
- Return type
Tensor
- sample(sample_shape=torch.Size([]))[source]¶
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
- Return type
Tensor
- sample_n(n)[source]¶
Generates n samples or n batches of samples if the distribution parameters are batched.
- Return type
Tensor
- static set_default_validate_args(value)[source]¶
Sets whether validation is enabled or disabled.
The default behavior mimics Python’s assert statement: validation is on by default, but is disabled if Python is run in optimized mode (via python -O). Validation may be expensive, so you may want to disable it once a model is working.
- Parameters
value (bool) – Whether to enable validation.
- property support: Optional[Any]¶
Returns a Constraint object representing this distribution’s support.
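The shape conventions used by sample(), log_prob(), enumerate_support() and expand() are easiest to see on concrete instances. The following sketch uses Normal and Bernoulli purely as examples of Distribution subclasses; the particular parameter values are arbitrary.

```python
import torch
from torch.distributions import Bernoulli, Normal

# A batch of 3 univariate normals: batch_shape == (3,), event_shape == ().
m = Normal(torch.zeros(3), torch.ones(3))
x = m.sample((5,))                 # sample_shape + batch_shape + event_shape
lp = m.log_prob(x)                 # same leading shape as x for univariate events

# enumerate_support on a batch of 2 Bernoullis.
b = Bernoulli(torch.tensor([0.2, 0.7]))
support = b.enumerate_support()                       # (cardinality,) + batch_shape
support_compact = b.enumerate_support(expand=False)   # singleton batch dims

# expand returns a new instance with a larger batch shape.
expanded = m.expand(torch.Size([4, 3]))

# perplexity is the exponential of the entropy.
perplexity = m.perplexity()
```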
ExponentialFamily¶
- class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]¶
Bases:
Distribution
ExponentialFamily is the abstract base class for probability distributions belonging to an exponential family, whose probability mass/density function has the form defined below
$$p_F(x; \theta) = \exp(\langle t(x), \theta \rangle - F(\theta) + k(x))$$
where $\theta$ denotes the natural parameters, $t(x)$ denotes the sufficient statistic, $F(\theta)$ is the log normalizer function for a given family and $k(x)$ is the carrier measure.
Note
This class is an intermediary between the Distribution class and distributions which belong to an exponential family mainly to check the correctness of the .entropy() and analytic KL divergence methods. We use this class to compute the entropy and KL divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and Cross-entropies of Exponential Families).
Bernoulli¶
- class torch.distributions.bernoulli.Bernoulli(probs=None, logits=None, validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates a Bernoulli distribution parameterized by probs or logits (but not both).
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
Example:
>>> m = Bernoulli(torch.tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
tensor([ 0.])
- Parameters
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}¶
- has_enumerate_support = True¶
- property logits¶
- property mean¶
- property mode¶
- property param_shape¶
- property probs¶
- support = Boolean()¶
- property variance¶
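The two parameterizations describe the same distribution when probs is the sigmoid of logits. A small sketch of that relationship, with arbitrary example values, also shows that supplying both arguments is rejected:

```python
import torch
from torch.distributions import Bernoulli

logits = torch.tensor([-1.0, 0.0, 2.0])
from_logits = Bernoulli(logits=logits)
from_probs = Bernoulli(probs=torch.sigmoid(logits))

x = torch.tensor([1.0, 0.0, 1.0])
lp_logits = from_logits.log_prob(x)
lp_probs = from_probs.log_prob(x)   # matches lp_logits up to rounding

# Passing both probs and logits raises a ValueError.
try:
    Bernoulli(probs=torch.tensor([0.5]), logits=torch.tensor([0.0]))
    both_rejected = False
except ValueError:
    both_rejected = True
```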
Beta¶
- class torch.distributions.beta.Beta(concentration1, concentration0, validate_args=None)[source]¶
Bases:
ExponentialFamily
Beta distribution parameterized by concentration1 and concentration0.
Example:
>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
tensor([ 0.1046])
- Parameters
- arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}¶
- property concentration0¶
- property concentration1¶
- has_rsample = True¶
- property mean¶
- property mode¶
- support = Interval(lower_bound=0.0, upper_bound=1.0)¶
- property variance¶
Binomial¶
- class torch.distributions.binomial.Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both). total_count must be broadcastable with probs/logits.
Example:
>>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))
>>> x = m.sample()
tensor([ 0., 22., 71., 100.])
>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
>>> x = m.sample()
tensor([[ 4., 5.],
        [ 7., 6.]])
- Parameters
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}¶
- has_enumerate_support = True¶
- property logits¶
- property mean¶
- property mode¶
- property param_shape¶
- property probs¶
- property support¶
- property variance¶
Categorical¶
- class torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a categorical distribution parameterized by either probs or logits (but not both).
Note
It is equivalent to the distribution that torch.multinomial() samples from.
Samples are integers from $\{0, \ldots, K-1\}$ where K is probs.size(-1).
If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.
If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value. The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
See also: torch.multinomial()
Example:
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor(3)
- Parameters
- arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}¶
- has_enumerate_support = True¶
- property logits¶
- property mean¶
- property mode¶
- property param_shape¶
- property probs¶
- property support¶
- property variance¶
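The normalization described in the note can be checked directly. This sketch, with arbitrary weights, passes unnormalized probs and logits and reads back the normalized values; it also shows that a 2-dimensional probs is treated as a batch.

```python
import math
import torch
from torch.distributions import Categorical

# Relative weights 1:1:2 are normalized along the last dimension.
m = Categorical(probs=torch.tensor([1.0, 1.0, 2.0]))
normalized_probs = m.probs

# Equal logits become log(1/3) after normalization.
n = Categorical(logits=torch.tensor([0.0, 0.0, 0.0]))
normalized_logits = n.logits

# A (2, 4) probs tensor is a batch of 2 distributions over 4 classes.
batch = Categorical(probs=torch.ones(2, 4))
batch_sample = batch.sample()
```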
Cauchy¶
- class torch.distributions.cauchy.Cauchy(loc, scale, validate_args=None)[source]¶
Bases:
Distribution
Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means 0 follows a Cauchy distribution.
Example:
>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
tensor([ 2.3214])
- Parameters
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- support = Real()¶
- property variance¶
Chi2¶
- class torch.distributions.chi2.Chi2(df, validate_args=None)[source]¶
Bases:
Gamma
Creates a Chi-squared distribution parameterized by shape parameter df. This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5), that is, a Gamma distribution with concentration 0.5*df and rate 0.5.
Example:
>>> m = Chi2(torch.tensor([1.0]))
>>> m.sample()  # Chi2 distributed with shape df=1
tensor([ 0.1046])
- arg_constraints = {'df': GreaterThan(lower_bound=0.0)}¶
- property df¶
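The stated equivalence with the Gamma distribution can be verified numerically. The sketch below compares log_prob for an arbitrary df and a few arbitrary evaluation points.

```python
import torch
from torch.distributions import Chi2, Gamma

df = torch.tensor([3.0])
chi2 = Chi2(df)
gamma = Gamma(0.5 * df, torch.tensor([0.5]))   # concentration, rate

x = torch.tensor([0.5, 1.0, 4.0])
lp_chi2 = chi2.log_prob(x)
lp_gamma = gamma.log_prob(x)   # expected to agree with lp_chi2
```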
ContinuousBernoulli¶
- class torch.distributions.continuous_bernoulli.ContinuousBernoulli(probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates a continuous Bernoulli distribution parameterized by probs or logits (but not both).
The distribution is supported in [0, 1] and parameterized by ‘probs’ (in (0,1)) or ‘logits’ (real-valued). Note that, unlike the Bernoulli, ‘probs’ does not correspond to a probability and ‘logits’ does not correspond to log-odds, but the same names are used due to the similarity with the Bernoulli. See [1] for more details.
Example:
>>> m = ContinuousBernoulli(torch.tensor([0.3]))
>>> m.sample()
tensor([ 0.2538])
- Parameters
[1] The continuous Bernoulli: fixing a pervasive error in variational autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019. https://arxiv.org/abs/1907.06845
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}¶
- has_rsample = True¶
- property logits¶
- property mean¶
- property param_shape¶
- property probs¶
- property stddev¶
- support = Interval(lower_bound=0.0, upper_bound=1.0)¶
- property variance¶
Dirichlet¶
- class torch.distributions.dirichlet.Dirichlet(concentration, validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates a Dirichlet distribution parameterized by concentration.
Example:
>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
tensor([ 0.1046, 0.8954])
- Parameters
concentration (Tensor) – concentration parameter of the distribution (often referred to as alpha)
- arg_constraints = {'concentration': IndependentConstraint(GreaterThan(lower_bound=0.0), 1)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- support = Simplex()¶
- property variance¶
Exponential¶
- class torch.distributions.exponential.Exponential(rate, validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates an Exponential distribution parameterized by rate.
Example:
>>> m = Exponential(torch.tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
tensor([ 0.1046])
- arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property stddev¶
- support = GreaterThanEq(lower_bound=0.0)¶
- property variance¶
FisherSnedecor¶
- class torch.distributions.fishersnedecor.FisherSnedecor(df1, df2, validate_args=None)[source]¶
Bases:
Distribution
Creates a Fisher-Snedecor distribution parameterized by df1 and df2.
Example:
>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
tensor([ 0.2453])
- Parameters
- arg_constraints = {'df1': GreaterThan(lower_bound=0.0), 'df2': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- support = GreaterThan(lower_bound=0.0)¶
- property variance¶
Gamma¶
- class torch.distributions.gamma.Gamma(concentration, rate, validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates a Gamma distribution parameterized by shape concentration and rate.
Example:
>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
tensor([ 0.1046])
- Parameters
- arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- support = GreaterThanEq(lower_bound=0.0)¶
- property variance¶
Geometric¶
- class torch.distributions.geometric.Geometric(probs=None, logits=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a Geometric distribution parameterized by probs, where probs is the probability of success of Bernoulli trials.
Note
With torch.distributions.geometric.Geometric() the $(k+1)$-th trial is the first success, hence it draws samples in $\{0, 1, \ldots\}$, whereas with torch.Tensor.geometric_() the $k$-th trial is the first success, hence it draws samples in $\{1, 2, \ldots\}$.
Example:
>>> m = Geometric(torch.tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
tensor([ 2.])
- Parameters
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}¶
- property logits¶
- property mean¶
- property mode¶
- property probs¶
- support = IntegerGreaterThan(lower_bound=0)¶
- property variance¶
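Since the support starts at 0, a sample counts the failures before the first success. A short sketch, assuming p = 0.3 as in the example above, checks log_prob against the closed form k * log(1 - p) + log(p):

```python
import math
import torch
from torch.distributions import Geometric

p = 0.3
m = Geometric(torch.tensor([p]))

k = torch.tensor([0.0, 1.0, 2.0])           # number of failures before the first success
lp = m.log_prob(k)
manual = k * math.log(1 - p) + math.log(p)  # log of (1 - p)^k * p

mean = m.mean                               # 1 / p - 1 for this convention
```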
Gumbel¶
- class torch.distributions.gumbel.Gumbel(loc, scale, validate_args=None)[source]¶
Bases:
TransformedDistribution
Samples from a Gumbel Distribution.
Examples:
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
tensor([ 1.0124])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}¶
- property mean¶
- property mode¶
- property stddev¶
- support = Real()¶
- property variance¶
HalfCauchy¶
- class torch.distributions.half_cauchy.HalfCauchy(scale, validate_args=None)[source]¶
Bases:
TransformedDistribution
Creates a half-Cauchy distribution parameterized by scale where:
X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)
Example:
>>> m = HalfCauchy(torch.tensor([1.0]))
>>> m.sample()  # half-cauchy distributed with scale=1
tensor([ 2.3214])
- arg_constraints: Dict[str, Constraint] = {'scale': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property scale¶
- support = GreaterThanEq(lower_bound=0.0)¶
- property variance¶
HalfNormal¶
- class torch.distributions.half_normal.HalfNormal(scale, validate_args=None)[source]¶
Bases:
TransformedDistribution
Creates a half-normal distribution parameterized by scale where:
X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)
Example:
>>> m = HalfNormal(torch.tensor([1.0]))
>>> m.sample()  # half-normal distributed with scale=1
tensor([ 0.1046])
- arg_constraints: Dict[str, Constraint] = {'scale': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property scale¶
- support = GreaterThanEq(lower_bound=0.0)¶
- property variance¶
Independent¶
- class torch.distributions.independent.Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)[source]¶
Bases:
Distribution
Reinterprets some of the batch dims of a distribution as event dims.
This is mainly useful for changing the shape of the result of log_prob(). For example to create a diagonal Normal distribution with the same shape as a Multivariate Normal distribution (so they are interchangeable), you can:
>>> from torch.distributions.independent import Independent
>>> from torch.distributions.multivariate_normal import MultivariateNormal
>>> from torch.distributions.normal import Normal
>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size([]), torch.Size([3])]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size([3]), torch.Size([])]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size([]), torch.Size([3])]
- Parameters
base_distribution (torch.distributions.distribution.Distribution) – a base distribution
reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims
- arg_constraints: Dict[str, Constraint] = {}¶
- property has_enumerate_support¶
- property has_rsample¶
- property mean¶
- property mode¶
- property support¶
- property variance¶
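The effect on log_prob() can be made explicit. In the sketch below, with an arbitrary evaluation point x, the base Normal returns one log-probability per component, while the Independent wrapper sums over the reinterpreted dimension and agrees with the equivalent MultivariateNormal.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

loc = torch.zeros(3)
scale = torch.ones(3)

normal = Normal(loc, scale)                 # batch_shape (3,), event_shape ()
diagn = Independent(normal, 1)              # batch_shape (),   event_shape (3,)
mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))

x = torch.tensor([0.1, -0.2, 0.3])
per_component = normal.log_prob(x)          # one value per batch element
joint = diagn.log_prob(x)                   # summed over the event dimension
joint_mvn = mvn.log_prob(x)
```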
InverseGamma¶
- class torch.distributions.inverse_gamma.InverseGamma(concentration, rate, validate_args=None)[source]¶
Bases:
TransformedDistribution
Creates an inverse gamma distribution parameterized by concentration and rate where:
X ~ Gamma(concentration, rate)
Y = 1 / X ~ InverseGamma(concentration, rate)
Example:
>>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
>>> m.sample()
tensor([ 1.2953])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}¶
- property concentration¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property rate¶
- support = GreaterThan(lower_bound=0.0)¶
- property variance¶
Kumaraswamy¶
- class torch.distributions.kumaraswamy.Kumaraswamy(concentration1, concentration0, validate_args=None)[source]¶
Bases:
TransformedDistribution
Samples from a Kumaraswamy distribution.
Example:
>>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
tensor([ 0.1729])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- support = Interval(lower_bound=0.0, upper_bound=1.0)¶
- property variance¶
LKJCholesky¶
- class torch.distributions.lkj_cholesky.LKJCholesky(dim, concentration=1.0, validate_args=None)[source]¶
Bases:
Distribution
LKJ distribution for lower Cholesky factor of correlation matrices. The distribution is controlled by the concentration parameter $\eta$ to make the probability of the correlation matrix $M$ generated from a Cholesky factor proportional to $\det(M)^{\eta - 1}$. Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices:
L ~ LKJCholesky(dim, concentration)
X = L @ L' ~ LKJCorr(dim, concentration)
Note that this distribution samples the Cholesky factor of correlation matrices and not the correlation matrices themselves and thereby differs slightly from the derivations in [1] for the LKJCorr distribution. For sampling, this uses the Onion method from [1] Section 3.
Example:
>>> l = LKJCholesky(3, 0.5)
>>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
tensor([[ 1.0000, 0.0000, 0.0000],
        [ 0.3516, 0.9361, 0.0000],
        [-0.1899, 0.4748, 0.8593]])
- Parameters
References
[1] Generating random correlation matrices based on vines and extended onion method (2009), Daniel Lewandowski, Dorota Kurowicka, Harry Joe. Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
- arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}¶
- support = CorrCholesky()¶
Laplace¶
- class torch.distributions.laplace.Laplace(loc, scale, validate_args=None)[source]¶
Bases:
Distribution
Creates a Laplace distribution parameterized by loc and scale.
Example:
>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # Laplace distributed with loc=0, scale=1
tensor([ 0.1046])
- Parameters
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property stddev¶
- support = Real()¶
- property variance¶
LogNormal¶
- class torch.distributions.log_normal.LogNormal(loc, scale, validate_args=None)[source]¶
Bases:
TransformedDistribution
Creates a log-normal distribution parameterized by loc and scale where:
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)
Example:
>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # log-normal distributed with loc=0 and scale=1 (mean and stddev of the underlying normal)
tensor([ 0.1046])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property loc¶
- property mean¶
- property mode¶
- property scale¶
- support = GreaterThan(lower_bound=0.0)¶
- property variance¶
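Because Y = exp(X), the log-density of Y follows from the change of variables: log p_Y(y) = log p_X(log y) - log y. A sketch that checks this identity for arbitrary points y, and that the mean is exp(loc + scale^2 / 2) rather than loc:

```python
import math
import torch
from torch.distributions import LogNormal, Normal

loc, scale = torch.tensor([0.0]), torch.tensor([1.0])
log_normal = LogNormal(loc, scale)
normal = Normal(loc, scale)

y = torch.tensor([0.5, 1.0, 3.0])
lp = log_normal.log_prob(y)
manual = normal.log_prob(y.log()) - y.log()   # change of variables

mean = log_normal.mean                        # exp(loc + scale**2 / 2)
```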
LowRankMultivariateNormal¶
- class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)[source]¶
Bases:
Distribution
Creates a multivariate normal distribution with covariance matrix having a low-rank form parameterized by cov_factor and cov_diag:
covariance_matrix = cov_factor @ cov_factor.T + cov_diag
Example
>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
>>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429])
- Parameters
loc (Tensor) – mean of the distribution with shape batch_shape + event_shape
cov_factor (Tensor) – factor part of low-rank form of covariance matrix with shape batch_shape + event_shape + (rank,)
cov_diag (Tensor) – diagonal part of low-rank form of covariance matrix with shape batch_shape + event_shape
Note
The computation for determinant and inverse of covariance matrix is avoided when cov_factor.shape[1] << cov_factor.shape[0] thanks to Woodbury matrix identity and matrix determinant lemma. Thanks to these formulas, we just need to compute the determinant and inverse of the small size “capacitance” matrix:
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
- arg_constraints = {'cov_diag': IndependentConstraint(GreaterThan(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': IndependentConstraint(Real(), 1)}¶
- property covariance_matrix¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property precision_matrix¶
- property scale_tril¶
- support = IndependentConstraint(Real(), 1)¶
- property variance¶
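In the covariance formula, cov_diag enters as a diagonal matrix. The sketch below, using arbitrary small tensors, builds the dense covariance by hand and compares the low-rank distribution against a MultivariateNormal constructed from it.

```python
import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal

loc = torch.zeros(3)
cov_factor = torch.tensor([[1.0], [0.5], [0.0]])   # shape event_shape + (rank,)
cov_diag = torch.tensor([1.0, 2.0, 0.5])

m = LowRankMultivariateNormal(loc, cov_factor, cov_diag)

# Dense form of the same covariance: the diagonal part is embedded as a matrix.
dense_cov = cov_factor @ cov_factor.T + torch.diag(cov_diag)
dense = MultivariateNormal(loc, covariance_matrix=dense_cov)

x = torch.tensor([0.3, -1.0, 0.7])
lp_low_rank = m.log_prob(x)
lp_dense = dense.log_prob(x)
```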
MixtureSameFamily¶
- class torch.distributions.mixture_same_family.MixtureSameFamily(mixture_distribution, component_distribution, validate_args=None)[source]¶
Bases:
Distribution
The MixtureSameFamily distribution implements a (batch of) mixture distribution where all components are from different parameterizations of the same distribution type. It is parameterized by a Categorical “selecting distribution” (over k components) and a component distribution, i.e., a Distribution with a rightmost batch shape (equal to [k]) which indexes each (batch of) component.
Examples:
>>> import torch.distributions as D
>>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
>>> # weighted normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
>>> gmm = D.MixtureSameFamily(mix, comp)
>>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
>>> # weighted bivariate normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Independent(D.Normal(
...     torch.randn(5,2), torch.rand(5,2)), 1)
>>> gmm = D.MixtureSameFamily(mix, comp)
>>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
>>> # consisting of 5 random weighted bivariate normal distributions
>>> mix = D.Categorical(torch.rand(3,5))
>>> comp = D.Independent(D.Normal(
...     torch.randn(3,5,2), torch.rand(3,5,2)), 1)
>>> gmm = D.MixtureSameFamily(mix, comp)
- Parameters
mixture_distribution – torch.distributions.Categorical-like instance. Manages the probability of selecting component. The number of categories must match the rightmost batch dimension of the component_distribution. Must have either scalar batch_shape or batch_shape matching component_distribution.batch_shape[:-1]
component_distribution – torch.distributions.Distribution-like instance. Right-most batch dimension indexes component.
- arg_constraints: Dict[str, Constraint] = {}¶
- property component_distribution¶
- has_rsample = False¶
- property mean¶
- property mixture_distribution¶
- property support¶
- property variance¶
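A mixture's density is the weighted sum of its component densities, so its log_prob is a logsumexp over the components. The following sketch uses a two-component 1D Gaussian mixture with arbitrary weights and parameters to check that relationship by hand.

```python
import torch
import torch.distributions as D

mix = D.Categorical(torch.tensor([0.2, 0.8]))
comp = D.Normal(torch.tensor([-1.0, 2.0]), torch.tensor([0.5, 1.0]))
gmm = D.MixtureSameFamily(mix, comp)

x = torch.tensor(0.0)
lp = gmm.log_prob(x)
# log sum_k w_k * N(x; mu_k, sigma_k), computed in log space.
manual = torch.logsumexp(mix.logits + comp.log_prob(x), dim=-1)

mean = gmm.mean              # weighted average of the component means
samples = gmm.sample((7,))   # not reparameterized: has_rsample is False
```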
Multinomial¶
- class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a Multinomial distribution parameterized by total_count and either probs or logits (but not both). The innermost dimension of probs indexes over categories. All other dimensions index over batches.
Note that total_count need not be specified if only log_prob() is called (see example below)
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value. The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
- sample() requires a single shared total_count for all parameters and samples.
- log_prob() allows different total_count for each parameter and sample.
Example:
>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
>>> x = m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 21., 24., 30., 25.])
>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])
- Parameters
- arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}¶
- property logits¶
- property mean¶
- property param_shape¶
- property probs¶
- property support¶
- property variance¶
MultivariateNormal¶
- class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a multivariate normal (also called Gaussian) distribution parameterized by a mean vector and a covariance matrix.
The multivariate normal distribution can be parameterized either in terms of a positive definite covariance matrix $\mathbf{\Sigma}$ or a positive definite precision matrix $\mathbf{\Sigma}^{-1}$ or a lower-triangular matrix $\mathbf{L}$ with positive-valued diagonal entries, such that $\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top$. This triangular matrix can be obtained via e.g. Cholesky decomposition of the covariance.
Example
>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])
- Parameters
Note
Only one of covariance_matrix or precision_matrix or scale_tril can be specified.
Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition.
- arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': IndependentConstraint(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}¶
- property covariance_matrix¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property precision_matrix¶
- property scale_tril¶
- support = IndependentConstraint(Real(), 1)¶
- property variance¶
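The three parameterizations are interchangeable. As a sketch with an arbitrary 2x2 covariance (its inverse written out by hand so that it is exactly symmetric), the same distribution can be built three ways and compared through log_prob:

```python
import torch
from torch.distributions import MultivariateNormal

loc = torch.zeros(2)
cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
# Inverse of cov, written explicitly: det(cov) = 1.75.
prec = torch.tensor([[1.0, -0.5], [-0.5, 2.0]]) / 1.75
tril = torch.linalg.cholesky(cov)          # lower-triangular L with cov = L @ L.T

by_cov = MultivariateNormal(loc, covariance_matrix=cov)
by_prec = MultivariateNormal(loc, precision_matrix=prec)
by_tril = MultivariateNormal(loc, scale_tril=tril)

x = torch.tensor([0.3, -0.7])
lp_cov, lp_prec, lp_tril = by_cov.log_prob(x), by_prec.log_prob(x), by_tril.log_prob(x)
```

Passing scale_tril skips the internal Cholesky decomposition, which is the efficiency gain the note refers to.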
NegativeBinomial¶
- class torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a Negative Binomial distribution, i.e. distribution of the number of successful independent and identical Bernoulli trials before total_count failures are achieved. The probability of success of each Bernoulli trial is probs.
- Parameters
- arg_constraints = {'logits': Real(), 'probs': HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThanEq(lower_bound=0)}¶
- property logits¶
- property mean¶
- property mode¶
- property param_shape¶
- property probs¶
- support = IntegerGreaterThan(lower_bound=0)¶
- property variance¶
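This section has no usage example, so here is a minimal sketch with arbitrary parameters. Under the convention above (successes counted before total_count failures), the mean is total_count * probs / (1 - probs), which the sample average should approach.

```python
import torch
from torch.distributions import NegativeBinomial

torch.manual_seed(0)

total_count = torch.tensor([5.0])
probs = torch.tensor([0.25])        # probability of success per trial
m = NegativeBinomial(total_count=total_count, probs=probs)

expected_mean = total_count * probs / (1 - probs)
samples = m.sample((20000,))        # shape (20000, 1), non-negative counts
sample_mean = samples.mean()
```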
Normal¶
- class torch.distributions.normal.Normal(loc, scale, validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates a normal (also called Gaussian) distribution parameterized by loc and scale.
Example:
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
tensor([ 0.1046])
- Parameters
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property stddev¶
- support = Real()¶
- property variance¶
OneHotCategorical¶
- class torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a one-hot categorical distribution parameterized by probs or logits.
Samples are one-hot coded vectors of size probs.size(-1).
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value. The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
See also: torch.distributions.Categorical() for specifications of probs and logits.
Example:
>>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 0., 0., 0., 1.])
- Parameters
- arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}¶
- has_enumerate_support = True¶
- property logits¶
- property mean¶
- property mode¶
- property param_shape¶
- property probs¶
- support = OneHot()¶
- property variance¶
Pareto¶
- class torch.distributions.pareto.Pareto(scale, alpha, validate_args=None)[source]¶
Bases:
TransformedDistribution
Samples from a Pareto Type 1 distribution.
Example:
>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Pareto distribution with scale=1 and alpha=1
tensor([ 1.5623])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}¶
- property mean¶
- property mode¶
- property support¶
- property variance¶
Poisson¶
- class torch.distributions.poisson.Poisson(rate, validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates a Poisson distribution parameterized by rate, the rate parameter.
Samples are nonnegative integers, with a pmf given by
$$\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}$$
Example:
>>> m = Poisson(torch.tensor([4]))
>>> m.sample()
tensor([ 3.])
- Parameters
rate (Number, Tensor) – the rate parameter
- arg_constraints = {'rate': GreaterThanEq(lower_bound=0.0)}¶
- property mean¶
- property mode¶
- support = IntegerGreaterThan(lower_bound=0)¶
- property variance¶
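As a quick check of the pmf, the sketch below compares log_prob() against the logarithm of rate^k * exp(-rate) / k! computed by hand. The particular rate and count are arbitrary illustrative values.

```python
import math

import torch
from torch.distributions import Poisson

rate = 4.0  # illustrative rate
m = Poisson(torch.tensor([rate]))

k = torch.tensor([3.0])  # illustrative count
log_p = m.log_prob(k)

# log(rate^k * exp(-rate) / k!) = k * log(rate) - rate - log(k!)
# where log(k!) = lgamma(k + 1).
expected = k * math.log(rate) - rate - torch.lgamma(k + 1)

print(log_p, expected)
```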
RelaxedBernoulli¶
- class torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)[source]¶
Bases:
TransformedDistribution
Creates a RelaxedBernoulli distribution, parametrized by temperature, and either probs or logits (but not both). This is a relaxed version of the Bernoulli distribution, so its values are in (0, 1) and its samples are reparametrizable.
Example:
>>> m = RelaxedBernoulli(torch.tensor([2.2]),
...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
>>> m.sample()
tensor([ 0.2951,  0.3442,  0.8918,  0.9021])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}¶
- has_rsample = True¶
- property logits¶
- property probs¶
- support = Interval(lower_bound=0.0, upper_bound=1.0)¶
- property temperature¶
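Since has_rsample is True, a relaxed sample can be differentiated with respect to the logits. A minimal sketch, assuming an arbitrary temperature of 0.5 and a toy objective (the sum of the samples):

```python
import torch
from torch.distributions import RelaxedBernoulli

torch.manual_seed(0)

temperature = torch.tensor([0.5])  # illustrative value
logits = torch.zeros(4, requires_grad=True)

m = RelaxedBernoulli(temperature, logits=logits)

# Reparameterized sample: values lie in the unit interval and
# remain attached to the autograd graph.
x = m.rsample()

# Toy objective, only to demonstrate that gradients reach logits.
x.sum().backward()

print(x, logits.grad)
```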
LogitRelaxedBernoulli¶
- class torch.distributions.relaxed_bernoulli.LogitRelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)[source]¶
Bases:
Distribution
Creates a LogitRelaxedBernoulli distribution parameterized by probs or logits (but not both), which is the logit of a RelaxedBernoulli distribution.
Samples are logits of values in (0, 1). See [1] for more details.
- Parameters
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al., 2017)
[2] Categorical Reparametrization with Gumbel-Softmax (Jang et al., 2017)
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}¶
- property logits¶
- property param_shape¶
- property probs¶
- support = Real()¶
RelaxedOneHotCategorical¶
- class torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)[source]¶
Bases:
TransformedDistribution
Creates a RelaxedOneHotCategorical distribution parametrized by temperature, and either probs or logits. This is a relaxed version of the OneHotCategorical distribution, so its samples lie on the simplex and are reparametrizable.
Example:
>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
>>> m.sample()
tensor([ 0.1294,  0.2324,  0.3859,  0.2523])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}¶
- has_rsample = True¶
- property logits¶
- property probs¶
- support = Simplex()¶
- property temperature¶
StudentT¶
- class torch.distributions.studentT.StudentT(df, loc=0.0, scale=1.0, validate_args=None)[source]¶
Bases:
Distribution
Creates a Student’s t-distribution parameterized by degrees of freedom df, mean loc and scale scale.
Example:
>>> m = StudentT(torch.tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
tensor([ 0.1046])
- Parameters
- arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- support = Real()¶
- property variance¶
TransformedDistribution¶
- class torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms, validate_args=None)[source]¶
Bases:
Distribution
Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.
An example for the usage of TransformedDistribution would be:
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)
For more examples, please look at the implementations of Gumbel, HalfCauchy, HalfNormal, LogNormal, Pareto, Weibull, RelaxedBernoulli and RelaxedOneHotCategorical.
- arg_constraints: Dict[str, Constraint] = {}¶
- cdf(value)[source]¶
Computes the cumulative distribution function by inverting the transform(s) and computing the score of the base distribution.
- property has_rsample¶
- icdf(value)[source]¶
Computes the inverse cumulative distribution function using transform(s) and computing the score of the base distribution.
- log_prob(value)[source]¶
Scores the sample by inverting the transform(s) and computing the score using the score of the base distribution and the log abs det jacobian.
- rsample(sample_shape=torch.Size([]))[source]¶
Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.
- Return type
- sample(sample_shape=torch.Size([]))[source]¶
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.
- property support¶
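The logistic construction described for TransformedDistribution can be checked numerically. The sketch below builds the same distribution with concrete, arbitrarily chosen values for a and b, and compares log_prob() with the closed-form logistic log-density; the evaluation points are likewise illustrative.

```python
import math

import torch
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import AffineTransform, SigmoidTransform

a, b = 1.0, 2.0  # illustrative location and scale

base_distribution = Uniform(0.0, 1.0)
# logit(X) followed by an affine map: Y = a + b * logit(X)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

y = torch.tensor([0.5, 1.0, 3.0])
log_p = logistic.log_prob(y)

# Closed-form log-density of Logistic(a, b):
# log p(y) = -z - log(b) - 2 * log(1 + exp(-z)), with z = (y - a) / b
z = (y - a) / b
closed_form = -z - math.log(b) - 2 * torch.log1p(torch.exp(-z))

print(log_p, closed_form)
```

log_prob() inverts each transform in reverse order, scores the result under the base distribution, and subtracts the accumulated log absolute determinant of the Jacobian, which is why the two values agree.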
Uniform¶
- class torch.distributions.uniform.Uniform(low, high, validate_args=None)[source]¶
Bases:
Distribution
Generates uniformly distributed random samples from the half-open interval [low, high).
Example:
>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
tensor([ 2.3418])
- Parameters
- arg_constraints = {'high': Dependent(), 'low': Dependent()}¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property stddev¶
- property support¶
- property variance¶
VonMises¶
- class torch.distributions.von_mises.VonMises(loc, concentration, validate_args=None)[source]¶
Bases:
Distribution
A circular von Mises distribution.
This implementation uses polar coordinates. The loc and value args can be any real number (to facilitate unconstrained optimization), but are interpreted as angles modulo 2 pi.
Example:
>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # von Mises distributed with loc=1 and concentration=1
tensor([1.9777])
- Parameters
loc (torch.Tensor) – an angle in radians.
concentration (torch.Tensor) – concentration parameter
- arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'loc': Real()}¶
- has_rsample = False¶
- property mean¶
The provided mean is the circular one.
- property mode¶
- sample(sample_shape=torch.Size([]))[source]¶
The sampling algorithm for the von Mises distribution is based on the following paper: D.J. Best and N.I. Fisher, “Efficient simulation of the von Mises distribution.” Applied Statistics (1979): 152-157.
Sampling is always done in double precision internally to avoid a hang in _rejection_sample() for small values of the concentration, which starts to happen for single precision around 1e-4 (see issue #88443).
- support = Real()¶
- property variance¶
The provided variance is the circular one.
Weibull¶
- class torch.distributions.weibull.Weibull(scale, concentration, validate_args=None)[source]¶
Bases:
TransformedDistribution
Samples from a two-parameter Weibull distribution.
Example:
>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
tensor([ 0.4784])
- Parameters
- arg_constraints: Dict[str, Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}¶
- property mean¶
- property mode¶
- support = GreaterThan(lower_bound=0.0)¶
- property variance¶
Wishart¶
- class torch.distributions.wishart.Wishart(df, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]¶
Bases:
ExponentialFamily
Creates a Wishart distribution parameterized by a symmetric positive definite matrix $\Sigma$, or its Cholesky decomposition $\Sigma = L L^\top$.
Example:
>>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
>>> m.sample()  # Wishart distributed with mean=`df * I` and
>>>             # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
- Parameters
df (float or Tensor) – real-valued parameter larger than the (dimension of Square matrix) - 1
covariance_matrix (Tensor) – positive-definite covariance matrix
precision_matrix (Tensor) – positive-definite precision matrix
scale_tril (Tensor) – lower-triangular factor of covariance, with positive-valued diagonal
Note
Only one of covariance_matrix or precision_matrix or scale_tril can be specified. Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition. torch.distributions.LKJCholesky is a restricted Wishart distribution. [1]
References
[1] Wang, Z., Wu, Y. and Chu, H., 2018. On equivalence of the LKJ distribution and the restricted Wishart distribution.
[2] Sawyer, S., 2007. Wishart Distributions and Inverse-Wishart Sampling.
[3] Anderson, T. W., 2003. An Introduction to Multivariate Statistical Analysis (3rd ed.).
[4] Odell, P. L. & Feiveson, A. H., 1966. A Numerical Procedure to Generate a Sample Covariance Matrix. JASA, 61(313):199-203.
[5] Ku, Y.-C. & Bloomfield, P., 2010. Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX.
- arg_constraints = {'covariance_matrix': PositiveDefinite(), 'df': GreaterThan(lower_bound=0), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}¶
- property covariance_matrix¶
- has_rsample = True¶
- property mean¶
- property mode¶
- property precision_matrix¶
- rsample(sample_shape=torch.Size([]), max_try_correction=None)[source]¶
Warning
In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples. Several attempts to correct singular samples are performed by default, but singular matrix samples may still be returned. Singular samples may yield -inf values in .log_prob(). In those cases, the user should validate the samples and either fix the value of df or adjust the max_try_correction argument of .rsample accordingly.
- Return type
- property scale_tril¶
- support = PositiveDefinite()¶
- property variance¶
KL Divergence¶
- torch.distributions.kl.kl_divergence(p, q)[source]¶
Compute Kullback-Leibler divergence between two distributions.
- Parameters
p (Distribution) – A Distribution object.
q (Distribution) – A Distribution object.
- Returns
A batch of KL divergences of shape batch_shape.
- Return type
- Raises
NotImplementedError – If the distribution types have not been registered via register_kl().
- KL divergence is currently implemented for the following distribution pairs:
Bernoulli and Bernoulli
Bernoulli and Poisson
Beta and Beta
Beta and ContinuousBernoulli
Beta and Exponential
Beta and Gamma
Beta and Normal
Beta and Pareto
Beta and Uniform
Binomial and Binomial
Categorical and Categorical
Cauchy and Cauchy
ContinuousBernoulli and ContinuousBernoulli
ContinuousBernoulli and Exponential
ContinuousBernoulli and Normal
ContinuousBernoulli and Pareto
ContinuousBernoulli and Uniform
Dirichlet and Dirichlet
Exponential and Beta
Exponential and ContinuousBernoulli
Exponential and Exponential
Exponential and Gamma
Exponential and Gumbel
Exponential and Normal
Exponential and Pareto
Exponential and Uniform
ExponentialFamily and ExponentialFamily
Gamma and Beta
Gamma and ContinuousBernoulli
Gamma and Exponential
Gamma and Gamma
Gamma and Gumbel
Gamma and Normal
Gamma and Pareto
Gamma and Uniform
Geometric and Geometric
Gumbel and Beta
Gumbel and ContinuousBernoulli
Gumbel and Exponential
Gumbel and Gamma
Gumbel and Gumbel
Gumbel and Normal
Gumbel and Pareto
Gumbel and Uniform
HalfNormal and HalfNormal
Independent and Independent
Laplace and Beta
Laplace and ContinuousBernoulli
Laplace and Exponential
Laplace and Gamma
Laplace and Laplace
Laplace and Normal
Laplace and Pareto
Laplace and Uniform
LowRankMultivariateNormal and LowRankMultivariateNormal
LowRankMultivariateNormal and MultivariateNormal
MultivariateNormal and LowRankMultivariateNormal
MultivariateNormal and MultivariateNormal
Normal and Beta
Normal and ContinuousBernoulli
Normal and Exponential
Normal and Gamma
Normal and Gumbel
Normal and Laplace
Normal and Normal
Normal and Pareto
Normal and Uniform
OneHotCategorical and OneHotCategorical
Pareto and Beta
Pareto and ContinuousBernoulli
Pareto and Exponential
Pareto and Gamma
Pareto and Normal
Pareto and Pareto
Pareto and Uniform
Poisson and Bernoulli
Poisson and Binomial
Poisson and Poisson
TransformedDistribution and TransformedDistribution
Uniform and Beta
Uniform and ContinuousBernoulli
Uniform and Exponential
Uniform and Gamma
Uniform and Gumbel
Uniform and Normal
Uniform and Pareto
Uniform and Uniform
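For a registered pair such as Normal and Normal, kl_divergence() returns one value per batch element. The sketch below, with arbitrary illustrative parameters, compares the result against the well-known closed form for two univariate Gaussians.

```python
import torch
from torch.distributions import Normal, kl_divergence

# Two batches of univariate normals (batch_shape = (2,)); values are illustrative.
p = Normal(torch.tensor([0.0, 1.0]), torch.tensor([1.0, 2.0]))
q = Normal(torch.tensor([0.5, 0.0]), torch.tensor([1.5, 1.0]))

kl = kl_divergence(p, q)  # shape: (2,)

# Closed form for KL(N(mu_p, s_p^2) || N(mu_q, s_q^2)):
# log(s_q / s_p) + (s_p^2 + (mu_p - mu_q)^2) / (2 * s_q^2) - 1/2
closed_form = (
    torch.log(q.scale / p.scale)
    + (p.scale ** 2 + (p.loc - q.loc) ** 2) / (2 * q.scale ** 2)
    - 0.5
)

print(kl, closed_form)
```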
- torch.distributions.kl.register_kl(type_p, type_q)[source]¶
Decorator to register a pairwise function with kl_divergence(). Usage:
@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here
Lookup returns the most specific (type,type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example to resolve the ambiguous situation:
@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...
@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...
you should register a third most-specific implementation, e.g.:
register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.
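To illustrate the most-specific-match lookup, the sketch below registers a KL function for a hypothetical subclass, MyNormal, which exists only for this example. The calls list is likewise an illustrative device for observing which implementation was dispatched.

```python
import torch
from torch.distributions import Normal, kl_divergence, register_kl


class MyNormal(Normal):
    """Hypothetical subclass, defined only to demonstrate dispatch."""


calls = []  # records each use of the custom implementation


@register_kl(MyNormal, MyNormal)
def _kl_mynormal_mynormal(p, q):
    calls.append("custom")
    var_ratio = (p.scale / q.scale) ** 2
    t1 = ((p.loc - q.loc) / q.scale) ** 2
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


p = MyNormal(torch.tensor(0.0), torch.tensor(1.0))
q = MyNormal(torch.tensor(1.0), torch.tensor(2.0))

# (MyNormal, MyNormal) is more specific than (Normal, Normal),
# so the function registered above is selected.
custom = kl_divergence(p, q)

# Plain Normal instances still use the built-in implementation.
builtin = kl_divergence(Normal(p.loc, p.scale), Normal(q.loc, q.scale))

print(custom, builtin, calls)
```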
Transforms¶
- class torch.distributions.transforms.AbsTransform(cache_size=0)[source]¶
Transform via the mapping $y = |x|$.
- class torch.distributions.transforms.AffineTransform(loc, scale, event_dim=0, cache_size=0)[source]¶
Transform via the pointwise affine mapping $y = \text{loc} + \text{scale} \times x$.
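A transform object exposes the forward map, its inverse via .inv, and log_abs_det_jacobian(). The sketch below exercises all three on an AffineTransform computing loc + scale * x; the values of loc, scale and x are illustrative.

```python
import math

import torch
from torch.distributions.transforms import AffineTransform

t = AffineTransform(loc=1.0, scale=3.0)  # illustrative parameters

x = torch.tensor([0.0, 1.0, 2.0])
y = t(x)            # forward: 1 + 3 * x
x_back = t.inv(y)   # inverse: (y - 1) / 3

# For a pointwise affine map, log|dy/dx| = log|scale| for every element.
ladj = t.log_abs_det_jacobian(x, y)

print(y, x_back, ladj, math.log(3.0))
```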
- class torch.distributions.transforms.CatTransform(tseq, dim=0, lengths=None, cache_size=0)[source]¶
Transform functor that applies a sequence of transforms tseq component-wise to each submatrix at dim, of length lengths[dim], in a way compatible with torch.cat().
Example:
x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
x = torch.cat([x0, x0], dim=0)
t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
y = t(x)
- class torch.distributions.transforms.ComposeTransform(parts, cache_size=0)[source]¶
Composes multiple transforms in a chain. The transforms being composed are responsible for caching.
- class torch.distributions.transforms.CorrCholeskyTransform(cache_size=0)[source]¶
Transforms an unconstrained real vector $x$ with length $D(D-1)/2$ into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:
First we convert x into a lower triangular matrix in row order.
For each row $X_i$ of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform $X_i$ into a unit Euclidean length vector using the following steps:
- Scales into the interval $(-1, 1)$ domain: $r_i = \tanh(X_i)$.
- Transforms into an unsigned domain: $z_i = r_i^2$.
- Applies $s_i = \text{StickBreakingTransform}(z_i)$.
- Transforms back into signed domain: $y_i = \text{sign}(r_i) \cdot \sqrt{s_i}$.
- class torch.distributions.transforms.CumulativeDistributionTransform(distribution, cache_size=0)[source]¶
Transform via the cumulative distribution function of a probability distribution.
- Parameters
distribution (Distribution) – Distribution whose cumulative distribution function to use for the transformation.
Example:
# Construct a Gaussian copula from a multivariate normal.
base_dist = MultivariateNormal(
    loc=torch.zeros(2),
    scale_tril=LKJCholesky(2).sample(),
)
transform = CumulativeDistributionTransform(Normal(0, 1))
copula = TransformedDistribution(base_dist, [transform])
- class torch.distributions.transforms.ExpTransform(cache_size=0)[source]¶
Transform via the mapping $y = \exp(x)$.
- class torch.distributions.transforms.IndependentTransform(base_transform, reinterpreted_batch_ndims, cache_size=0)[source]¶
Wrapper around another transform to treat reinterpreted_batch_ndims-many extra of the rightmost dimensions as dependent. This has no effect on the forward or backward transforms, but does sum out reinterpreted_batch_ndims-many of the rightmost dimensions in log_abs_det_jacobian().
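The effect of IndependentTransform is visible only in the shape of log_abs_det_jacobian(). A small sketch, wrapping ExpTransform and using an arbitrary (3, 4) input:

```python
import torch
from torch.distributions.transforms import ExpTransform, IndependentTransform

torch.manual_seed(0)

base = ExpTransform()
# Treat the last dimension as a single dependent event.
t = IndependentTransform(base, reinterpreted_batch_ndims=1)

x = torch.randn(3, 4)

# The forward values are identical with or without the wrapper.
y_base = base(x)
y_wrapped = t(x)

# The base transform reports one log-Jacobian term per element,
# while the wrapper sums them over the rightmost dimension.
elementwise = base.log_abs_det_jacobian(x, y_base)   # shape: (3, 4)
joint = t.log_abs_det_jacobian(x, y_wrapped)         # shape: (3,)

print(elementwise.shape, joint.shape)
```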
- class torch.distributions.transforms.LowerCholeskyTransform(cache_size=0)[source]¶
Transform from unconstrained matrices to lower-triangular matrices with nonnegative diagonal entries.
This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.
- class torch.distributions.transforms.PositiveDefiniteTransform(cache_size=0)[source]¶
Transform from unconstrained matrices to positive-definite matrices.
- class torch.distributions.transforms.PowerTransform(exponent, cache_size=0)[source]¶
Transform via the mapping $y = x^{\text{exponent}}$.
- class torch.distributions.transforms.ReshapeTransform(in_shape, out_shape, cache_size=0)[source]¶
Unit Jacobian transform to reshape the rightmost part of a tensor.
Note that in_shape and out_shape must have the same number of elements, just as for torch.Tensor.reshape().
- Parameters
in_shape (torch.Size) – The input event shape.
out_shape (torch.Size) – The output event shape.
- class torch.distributions.transforms.SigmoidTransform(cache_size=0)[source]¶
Transform via the mapping $y = \frac{1}{1 + \exp(-x)}$ and $x = \text{logit}(y)$.
- class torch.distributions.transforms.SoftplusTransform(cache_size=0)[source]¶
Transform via the mapping $\text{Softplus}(x) = \log(1 + \exp(x))$. The implementation reverts to the linear function when $x > 20$.
- class torch.distributions.transforms.TanhTransform(cache_size=0)[source]¶
Transform via the mapping $y = \tanh(x)$.
It is equivalent to
ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
However this might not be numerically stable, thus it is recommended to use TanhTransform instead.
Note that one should use cache_size=1 when it comes to NaN/Inf values.
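The stated equivalence follows from the identity tanh(x) = 2 * sigmoid(2x) - 1. The sketch below checks it on a moderate, arbitrarily chosen input range where both forms are well behaved; it does not demonstrate the stability difference at extreme values.

```python
import torch
from torch.distributions.transforms import (
    AffineTransform,
    ComposeTransform,
    SigmoidTransform,
    TanhTransform,
)

x = torch.linspace(-3.0, 3.0, 7)

tanh_t = TanhTransform()
composed = ComposeTransform(
    [AffineTransform(0.0, 2.0), SigmoidTransform(), AffineTransform(-1.0, 2.0)]
)

y_tanh = tanh_t(x)
y_composed = composed(x)

# Both transforms should also agree on the log-Jacobian, log(1 - tanh(x)^2).
ladj_tanh = tanh_t.log_abs_det_jacobian(x, y_tanh)
ladj_composed = composed.log_abs_det_jacobian(x, y_composed)

print(y_tanh, y_composed)
```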
- class torch.distributions.transforms.SoftmaxTransform(cache_size=0)[source]¶
Transform from unconstrained space to the simplex via $y = \exp(x)$ then normalizing.
This is not bijective and cannot be used for HMC. However this acts mostly coordinate-wise (except for the final normalization), and thus is appropriate for coordinate-wise optimization algorithms.
- class torch.distributions.transforms.StackTransform(tseq, dim=0, cache_size=0)[source]¶
Transform functor that applies a sequence of transforms tseq component-wise to each submatrix at dim in a way compatible with torch.stack().
Example:
x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
t = StackTransform([ExpTransform(), identity_transform], dim=1)
y = t(x)
- class torch.distributions.transforms.StickBreakingTransform(cache_size=0)[source]¶
Transform from unconstrained space to the simplex of one additional dimension via a stick-breaking process.
This transform arises as an iterated sigmoid transform in a stick-breaking construction of the Dirichlet distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.
This is bijective and appropriate for use in HMC; however it mixes coordinates together and is less appropriate for optimization.
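The difference between the two simplex transforms can be seen from their output shapes and invertibility. A minimal sketch with an arbitrary three-element input:

```python
import torch
from torch.distributions.transforms import SoftmaxTransform, StickBreakingTransform

x = torch.tensor([0.5, -1.0, 2.0])  # illustrative unconstrained input

sb = StickBreakingTransform()
sm = SoftmaxTransform()

# Stick-breaking adds one dimension: R^3 -> simplex with 4 entries.
y_sb = sb(x)
# Softmax keeps the dimension: R^3 -> simplex with 3 entries.
y_sm = sm(x)

# Stick-breaking is bijective, so the input can be recovered.
x_back = sb.inv(y_sb)

# Softmax is not injective: shifting every coordinate by the same
# constant yields the same point on the simplex.
y_sm_shifted = sm(x + 3.0)

print(y_sb, y_sm, x_back)
```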
- class torch.distributions.transforms.Transform(cache_size=0)[source]¶
Abstract class for invertible transformations with computable log det jacobians. They are primarily used in torch.distributions.TransformedDistribution.
Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values since the autograd graph may be reversed. For example while the following works with or without caching:
y = t(x)
t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.
However the following will error when caching due to dependency reversal:
y = t(x)
z = t.inv(y)
grad(z.sum(), [y])  # error because z is x
Derived classes should implement one or both of _call() or _inverse(). Derived classes that set bijective=True should also implement log_abs_det_jacobian().
- Parameters
cache_size (int) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.
- Variables
domain (Constraint) – The constraint representing valid inputs to this transform.
codomain (Constraint) – The constraint representing valid outputs to this transform which are inputs to the inverse transform.
bijective (bool) – Whether this transform is bijective. A transform t is bijective iff t.inv(t(x)) == x and t(t.inv(y)) == y for every x in the domain and y in the codomain. Transforms that are not bijective should at least maintain the weaker pseudoinverse properties t(t.inv(t(x))) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
sign (int or Tensor) – For bijective univariate transforms, this should be +1 or -1 depending on whether transform is monotone increasing or decreasing.
- property sign¶
Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.
- log_abs_det_jacobian(x, y)[source]¶
Computes the log det jacobian log |dy/dx| given input and output.
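As a rough sketch of what a derived class might look like, the hypothetical CubeTransform below implements _call(), _inverse() and log_abs_det_jacobian() for the map y = x^3. The class name and the choice of map are assumptions made for illustration; the Jacobian vanishes at x = 0, so the example avoids that point.

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform


class CubeTransform(Transform):
    """Hypothetical bijective transform y = x ** 3 on the real line."""

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1  # monotone increasing

    def _call(self, x):
        return x ** 3

    def _inverse(self, y):
        # Real cube root that also handles negative inputs.
        return y.sign() * y.abs().pow(1.0 / 3.0)

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = 3 * x^2, so log|dy/dx| = log(3 * x^2).
        return torch.log(3 * x ** 2)


t = CubeTransform()
x = torch.tensor([1.0, 2.0, -1.5])
y = t(x)
x_back = t.inv(y)
ladj = t.log_abs_det_jacobian(x, y)

print(y, x_back, ladj)
```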
Constraints¶
The following constraints are implemented:
constraints.boolean
constraints.cat
constraints.corr_cholesky
constraints.dependent
constraints.greater_than(lower_bound)
constraints.greater_than_eq(lower_bound)
constraints.independent(constraint, reinterpreted_batch_ndims)
constraints.integer_interval(lower_bound, upper_bound)
constraints.interval(lower_bound, upper_bound)
constraints.less_than(upper_bound)
constraints.lower_cholesky
constraints.lower_triangular
constraints.multinomial
constraints.nonnegative
constraints.nonnegative_integer
constraints.one_hot
constraints.positive_integer
constraints.positive
constraints.positive_semidefinite
constraints.positive_definite
constraints.real_vector
constraints.real
constraints.simplex
constraints.symmetric
constraints.stack
constraints.square
constraints.unit_interval
- class torch.distributions.constraints.Constraint[source]¶
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
- Variables
- torch.distributions.constraints.cat¶
alias of
_Cat
- torch.distributions.constraints.dependent_property¶
alias of
_DependentProperty
- torch.distributions.constraints.greater_than¶
alias of
_GreaterThan
- torch.distributions.constraints.greater_than_eq¶
alias of
_GreaterThanEq
- torch.distributions.constraints.independent¶
alias of
_IndependentConstraint
- torch.distributions.constraints.integer_interval¶
alias of
_IntegerInterval
- torch.distributions.constraints.interval¶
alias of
_Interval
- torch.distributions.constraints.half_open_interval¶
alias of
_HalfOpenInterval
- torch.distributions.constraints.is_dependent(constraint)[source]¶
Checks if constraint is a _Dependent object.
- Parameters
constraint – A Constraint object.
- Returns
True if constraint can be refined to the type _Dependent, False otherwise.
- Return type
bool
Examples
>>> import torch
>>> from torch.distributions import Bernoulli
>>> from torch.distributions.constraints import is_dependent
>>> dist = Bernoulli(probs = torch.tensor([0.6], requires_grad=True))
>>> constraint1 = dist.arg_constraints["probs"]
>>> constraint2 = dist.arg_constraints["logits"]
>>> for constraint in [constraint1, constraint2]:
>>>     if is_dependent(constraint):
>>>         continue
- torch.distributions.constraints.less_than¶
alias of
_LessThan
- torch.distributions.constraints.multinomial¶
alias of
_Multinomial
- torch.distributions.constraints.stack¶
alias of
_Stack
Constraint Registry¶
PyTorch provides two global ConstraintRegistry objects that link Constraint objects to Transform objects. Both objects take constraints as input and return transforms, but they have different guarantees on bijectivity.
biject_to(constraint) looks up a bijective Transform from constraints.real to the given constraint. The returned transform is guaranteed to have .bijective = True and should implement .log_abs_det_jacobian().
transform_to(constraint) looks up a not-necessarily bijective Transform from constraints.real to the given constraint. The returned transform is not guaranteed to implement .log_abs_det_jacobian().
The transform_to()
registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution’s .arg_constraints
dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam:
loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()
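A complete version of this pattern might look like the sketch below, which fits the loc and scale of a Normal to synthetic data with Adam. The data, learning rate and step count are assumptions chosen for illustration; whichever transform the registry returns for the scale constraint is used as-is.

```python
import torch
from torch.distributions import Normal, transform_to

torch.manual_seed(0)

# Synthetic data with mean 3 and standard deviation 2 (illustrative).
data = 3.0 + 2.0 * torch.randn(1000)

loc = torch.zeros((), requires_grad=True)
unconstrained = torch.zeros((), requires_grad=True)

# Maps an unconstrained real number to a valid (positive) scale.
to_scale = transform_to(Normal.arg_constraints["scale"])

optimizer = torch.optim.Adam([loc, unconstrained], lr=0.1)
for _ in range(500):
    optimizer.zero_grad()
    scale = to_scale(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).mean()
    loss.backward()
    optimizer.step()

fitted_scale = to_scale(unconstrained).detach()
print(loc.item(), fitted_scale.item())
```

The optimizer only ever sees unconstrained, so no projection or clamping is needed to keep scale positive.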
The biject_to() registry is useful for Hamiltonian Monte Carlo, where samples from a probability distribution with constrained .support are propagated in an unconstrained space, and algorithms are typically rotation invariant:
dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()
Note
An example where transform_to and biject_to differ is constraints.simplex: transform_to(constraints.simplex) returns a SoftmaxTransform that simply exponentiates and normalizes its inputs; this is a cheap and mostly coordinate-wise operation appropriate for algorithms like SVI. In contrast, biject_to(constraints.simplex) returns a StickBreakingTransform that bijects its input down to a one-fewer-dimensional space; this is a more expensive, less numerically stable transform but is needed for algorithms like HMC.
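The distinction described in this note can be observed directly by looking up both registries for constraints.simplex. The input below is an arbitrary vector of zeros.

```python
import torch
from torch.distributions import biject_to, constraints, transform_to
from torch.distributions.transforms import SoftmaxTransform, StickBreakingTransform

t_cheap = transform_to(constraints.simplex)
t_bijective = biject_to(constraints.simplex)

x = torch.zeros(3)

# Softmax keeps the dimension: 3 inputs -> 3 probabilities.
y_cheap = t_cheap(x)
# Stick-breaking adds one: 3 inputs -> 4 probabilities.
y_bijective = t_bijective(x)

print(type(t_cheap).__name__, y_cheap)
print(type(t_bijective).__name__, y_bijective)
```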
The biject_to
and transform_to
objects can be extended by user-defined
constraints and transforms using their .register()
method either as a
function on singleton constraints:
transform_to.register(my_constraint, my_transform)
or as a decorator on parameterized constraints:
@transform_to.register(MyConstraintClass)
def my_factory(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.param1, constraint.param2)
You can create your own registry by creating a new ConstraintRegistry
object.
- class torch.distributions.constraint_registry.ConstraintRegistry[source]¶
Registry to link constraints to transforms.
- register(constraint, factory=None)[source]¶
Registers a Constraint subclass in this registry. Usage:
@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.arg_constraints)
- Parameters
constraint (subclass of Constraint) – A subclass of Constraint, or a singleton object of the desired class.
factory (Callable) – A callable that inputs a constraint object and returns a Transform object.
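A separate registry can be built from existing constraints and transforms. The sketch below creates one and registers a factory for constraints.greater_than; the names my_registry and _to_greater_than, and the choice of an exponential followed by a shift, are illustrative assumptions.

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraint_registry import ConstraintRegistry
from torch.distributions.transforms import (
    AffineTransform,
    ComposeTransform,
    ExpTransform,
)

my_registry = ConstraintRegistry()  # independent of biject_to / transform_to


@my_registry.register(constraints.greater_than)
def _to_greater_than(constraint):
    # exp(x) is positive; adding lower_bound shifts it above the bound.
    return ComposeTransform(
        [ExpTransform(), AffineTransform(constraint.lower_bound, 1)]
    )


t = my_registry(constraints.greater_than(2.0))
y = t(torch.tensor([-5.0, 0.0, 5.0]))

# Constraints without a registered factory raise NotImplementedError.
try:
    my_registry(constraints.simplex)
    lookup_failed = False
except NotImplementedError:
    lookup_failed = True

print(y, lookup_failed)
```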