Probability distributions - torch.distributions¶
The distributions
package contains parameterizable probability distributions
and sampling functions.
Policy gradient methods can be implemented using the
log_prob()
method, when the probability
density function is differentiable with respect to its parameters. A basic
method is the REINFORCE rule:

Δθ = α r ∂ log p(a | π^θ(s)) / ∂θ

where θ are the parameters, α is the learning rate, r is the reward, and p(a | π^θ(s)) is the probability of taking action a in state s given policy π^θ.
In practice we would sample an action from the output of a network, apply this
action in an environment, and then use log_prob
to construct an equivalent
loss function. Note that we use a negative because optimisers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows:
from torch.distributions import Categorical

probs = policy_network(state)
# NOTE: this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
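Putting the pieces together, a minimal sketch of one full update step that wires the loss above into a standard optimizer. Here policy_network, state and env are placeholders (any nn.Module producing action probabilities and any environment with the step() interface used above), and the choice of Adam with this learning rate is only an illustrative assumption:

import torch
from torch.distributions import Categorical

optimizer = torch.optim.Adam(policy_network.parameters(), lr=1e-2)

probs = policy_network(state)          # action probabilities for the current state
m = Categorical(probs)
action = m.sample()                    # sample an action from the current policy
next_state, reward = env.step(action)

optimizer.zero_grad()
loss = -m.log_prob(action) * reward    # negated so that minimizing it ascends the reward
loss.backward()
optimizer.step()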
Distribution¶
class torch.distributions.Distribution[source]¶
Distribution is the abstract base class for probability distributions.
log_prob(value)[source]¶
Returns the log of the probability density/mass function evaluated at value.
Parameters: value (Tensor or Variable) –
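For example, using the Bernoulli distribution defined below (a small sketch; the printed value is the analytic result log(0.3), shown approximately):

>>> m = Bernoulli(torch.Tensor([0.3]))
>>> m.log_prob(torch.Tensor([1.0]))  # log probability of sampling 1, i.e. log(0.3)
-1.2040
[torch.FloatTensor of size 1]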
Bernoulli¶
class torch.distributions.Bernoulli(probs)[source]¶
Creates a Bernoulli distribution parameterized by probs.
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
Example:
>>> m = Bernoulli(torch.Tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
 0.0
[torch.FloatTensor of size 1]
Parameters: probs (Tensor or Variable) – the probability of sampling 1
Categorical¶
class torch.distributions.Categorical(probs)[source]¶
Creates a categorical distribution parameterized by probs.
Note
It is equivalent to the distribution that multinomial() samples from.
Samples are integers from 0 ... K-1 where K is probs.size(-1).
If probs is 1D with length K, each element is the relative probability of sampling the class at that index.
If probs is 2D, it is treated as a batch of probability vectors (a batched sketch follows the example below).
See also: torch.multinomial()
Example:
>>> m = Categorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
 3
[torch.LongTensor of size 1]
Parameters: probs (Tensor or Variable) – event probabilities
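To illustrate the batched (2D probs) case mentioned in the note above, a small sketch; sampling is random, so the printed indices are just one possible draw:

>>> probs = torch.Tensor([[0.1, 0.9], [0.8, 0.2]])  # batch of two probability vectors
>>> m = Categorical(probs)
>>> m.sample()  # one class index drawn per row
 1
 0
[torch.LongTensor of size 2]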
Normal¶
class torch.distributions.Normal(mean, std)[source]¶
Creates a normal (also called Gaussian) distribution parameterized by mean and std.
Example:
>>> m = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
>>> m.sample()  # normally distributed with mean=0 and stddev=1
 0.1046
[torch.FloatTensor of size 1]
Parameters: