Probability distributions - torch.distributions
The distributions package contains parameterizable probability distributions
and sampling functions.
Policy gradient methods can be implemented using the
log_prob() method when the probability
density function is differentiable with respect to its parameters. A basic
method is the REINFORCE rule:
\[\Delta\theta = \alpha r \frac{\partial \log p(a|\pi^\theta(s))}{\partial\theta}\]
where \(\theta\) are the parameters, \(\alpha\) is the learning rate, \(r\) is the reward and \(p(a|\pi^\theta(s))\) is the probability of taking action \(a\) in state \(s\) given policy \(\pi^\theta\).
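To make the rule concrete, here is a minimal stdlib sketch under assumptions that are not part of the text above: there is a single state, and the policy's parameters \(\theta\) are directly the logits of a softmax over three actions. With that choice the gradient can be written by hand, since \(\partial \log p(a) / \partial \theta_k = 1[k = a] - p_k\). The helper names softmax and reinforce_step are hypothetical; in PyTorch, autograd would compute this gradient from log_prob instead.

```python
import math


def softmax(logits):
    """Convert unnormalized scores into probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(theta, action, reward, alpha):
    """One REINFORCE update for a hypothetical softmax policy whose parameters are the logits.

    For p = softmax(theta): d log p[action] / d theta[k] = 1[k == action] - p[k].
    """
    p = softmax(theta)
    grad_log_p = [(1.0 if k == action else 0.0) - p[k] for k in range(len(theta))]
    # Gradient *ascent*: theta <- theta + alpha * r * d log p(a) / d theta
    return [t + alpha * reward * g for t, g in zip(theta, grad_log_p)]


theta = [0.0, 0.0, 0.0]  # uniform policy over three actions
new_theta = reinforce_step(theta, action=1, reward=1.0, alpha=0.1)
print(new_theta)
print(softmax(new_theta))  # action 1 is now more likely than 1/3
```

With a positive reward the probability of the sampled action rises; with a negative reward it falls, which is the behaviour the rule encodes.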
In practice we would sample an action from the output of a network, apply this
action in an environment, and then use log_prob to construct an equivalent
loss function. Note that the loss is negated because optimisers perform gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows:
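    from torch.distributions import Categorical

    # policy_network, state and env stand in for your own model, observation and environment
    probs = policy_network(state)
    # NOTE: this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    # Negated because optimisers minimise, while REINFORCE ascends the gradient
    loss = -m.log_prob(action) * reward
    loss.backward()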
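The sign flip in this loss can be checked with a short derivation. It treats the reward \(r\) as a constant with respect to \(\theta\), which holds here because the reward comes from the environment rather than from the network. One gradient-descent step on the negated loss reproduces the REINFORCE update exactly:

```latex
\begin{aligned}
L(\theta) &= -r \log p(a|\pi^\theta(s)) \\
\theta &\leftarrow \theta - \alpha \frac{\partial L}{\partial \theta}
        = \theta + \alpha r \frac{\partial \log p(a|\pi^\theta(s))}{\partial \theta}
\end{aligned}
```

The resulting change, \(\alpha r \, \partial \log p(a|\pi^\theta(s)) / \partial \theta\), is the \(\Delta\theta\) of the REINFORCE rule.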
Distribution is the abstract base class for probability distributions.
log_prob(value): Returns the log of the probability density/mass function evaluated at value.
Parameters: value (Tensor or Variable) – the value at which to evaluate the log-probability
sample(): Generates a single sample or single batch of samples if the distribution parameters are batched.
sample_n(n): Generates n samples or n batches of samples if the distribution parameters are batched.
Creates a Bernoulli distribution parameterized by probs.
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
    >>> m = Bernoulli(torch.Tensor([0.3]))
    >>> m.sample()  # 30% chance 1; 70% chance 0
     0.0
    [torch.FloatTensor of size 1]
Parameters: probs (Tensor or Variable) – the probability of sampling 1
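The Distribution interface (sample, sample_n, log_prob) and the Bernoulli distribution can be illustrated together with a small pure-Python sketch. It is an assumption-laden mirror, not torch's implementation: it works on plain floats and ints rather than Tensors or Variables, ignores batching, and implements sample_n naively as n calls to sample(). The class names reuse the documented names purely for illustration.

```python
import math
import random
from abc import ABC, abstractmethod


class Distribution(ABC):
    """Hypothetical stdlib mirror of the Distribution interface (not torch's code)."""

    @abstractmethod
    def sample(self):
        """Return a single sample."""

    def sample_n(self, n):
        """Return n independent samples by calling sample() n times."""
        return [self.sample() for _ in range(n)]

    @abstractmethod
    def log_prob(self, value):
        """Return the log of the probability mass/density at value."""


class Bernoulli(Distribution):
    """Takes the value 1 with probability p and 0 with probability 1 - p."""

    def __init__(self, p):
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must lie in [0, 1]")
        self.p = p

    def sample(self):
        # random.random() is uniform on [0, 1), so this returns 1 with probability p
        return 1 if random.random() < self.p else 0

    def log_prob(self, value):
        # value is assumed to be 0 or 1: log(p) for 1, log(1 - p) for 0
        prob = self.p if value == 1 else 1.0 - self.p
        return math.log(prob) if prob > 0 else -math.inf


m = Bernoulli(0.3)
print(m.sample())                    # 1 with probability 0.3, otherwise 0
samples = m.sample_n(10_000)
print(sum(samples) / len(samples))   # close to 0.3
print(m.log_prob(1), m.log_prob(0))  # log(0.3) and log(0.7)
```

Because Distribution is an abstract base class, it cannot be instantiated directly; only subclasses that implement both sample and log_prob can.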
Creates a categorical distribution parameterized by probs.
It is equivalent to the distribution that torch.multinomial() samples from.
Samples are integers from 0, ..., K-1 where K is probs.size(-1).
If probs is 1D with length K, each element is the relative probability of sampling the class at that index.
If probs is 2D, it is treated as a batch of probability vectors.
    >>> m = Categorical(torch.Tensor([0.25, 0.25, 0.25, 0.25]))
    >>> m.sample()  # equal probability of 0, 1, 2, 3
     3
    [torch.LongTensor of size 1]
Parameters: probs (Tensor or Variable) – event probabilities
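The "relative probability" behaviour of Categorical can be sketched in pure Python: normalize the weights, then sample by walking the cumulative distribution (inverse-CDF sampling). This is a hypothetical illustration, not torch's implementation; it only handles the 1D case and does not cover the 2D batched case described above.

```python
import math
import random


class Categorical:
    """Hypothetical stdlib sketch: samples integers 0 ... K-1 from relative probabilities."""

    def __init__(self, probs):
        total = float(sum(probs))
        if total <= 0 or any(p < 0 for p in probs):
            raise ValueError("probs must be non-negative with a positive sum")
        # Normalize so relative weights such as [1, 1, 2] are accepted
        self.probs = [p / total for p in probs]

    def sample(self):
        # Inverse-CDF sampling: return the first index whose cumulative sum exceeds u
        u = random.random()
        cumulative = 0.0
        for k, p in enumerate(self.probs):
            cumulative += p
            if u < cumulative:
                return k
        # Floating-point round-off guard: fall back to the last class with nonzero mass
        return max(k for k, p in enumerate(self.probs) if p > 0)

    def log_prob(self, k):
        p = self.probs[k]
        return math.log(p) if p > 0 else -math.inf


m = Categorical([0.25, 0.25, 0.25, 0.25])
print(m.sample())          # each of 0, 1, 2, 3 with probability 0.25
print(m.log_prob(3))       # log(0.25)
weighted = Categorical([1, 1, 2])  # class 2 is twice as likely as class 0 or 1
print(weighted.probs)      # [0.25, 0.25, 0.5]
```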
Creates a normal (also called Gaussian) distribution parameterized by mean and std.
    >>> m = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
    >>> m.sample()  # normally distributed with mean=0 and stddev=1
     0.1046
    [torch.FloatTensor of size 1]
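For reference, the log-density a normal distribution evaluates is \(\log \mathcal{N}(x; \mu, \sigma) = -\frac{(x-\mu)^2}{2\sigma^2} - \log\sigma - \frac{1}{2}\log(2\pi)\). The following pure-Python sketch applies that formula and samples with the standard library's random.gauss. It is a hypothetical illustration on plain floats, not torch's implementation.

```python
import math
import random


class Normal:
    """Hypothetical stdlib sketch of a normal distribution with mean and std."""

    def __init__(self, mean, std):
        if std <= 0:
            raise ValueError("std must be positive")
        self.mean = mean
        self.std = std

    def sample(self):
        # random.gauss(mu, sigma) draws from a normal with mean mu and stddev sigma
        return random.gauss(self.mean, self.std)

    def log_prob(self, value):
        # -(x - mu)^2 / (2 sigma^2) - log(sigma) - 0.5 * log(2 pi)
        z = (value - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std) - 0.5 * math.log(2 * math.pi)


m = Normal(0.0, 1.0)
print(m.sample())       # normally distributed with mean=0 and stddev=1
print(m.log_prob(0.0))  # about -0.9189, i.e. -0.5 * log(2 pi)
```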