Ordinal¶
- class torchrl.modules.Ordinal(scores: Tensor)[source]¶
A discrete distribution for learning to sample from finite ordered sets.
It is defined in contrast with the Categorical distribution, which does not impose any notion of proximity or ordering over its support’s atoms. The Ordinal distribution explicitly encodes those concepts, which is useful for learning discrete sampling from continuous sets. See §5 of `Tang & Agrawal, 2020 <https://arxiv.org/pdf/1901.10500.pdf>`_ for details.
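In the parametrisation described in the reference, each atom’s logit is built from cumulative log-sigmoid terms of the scores, so that neighbouring atoms receive correlated probability mass. The snippet below is a minimal sketch of that construction, not the internal implementation of this class; the helper name ``ordinal_logits_sketch`` is purely illustrative.
>>> import torch
>>> import torch.nn.functional as F
>>> def ordinal_logits_sketch(scores: torch.Tensor) -> torch.Tensor:
...     # Accumulate log sigmoid(s_j) for atoms j <= i and log(1 - sigmoid(s_j)) for atoms j > i
...     log_p = F.logsigmoid(scores)
...     log_1mp = F.logsigmoid(-scores)
...     below = torch.cumsum(log_p, dim=-1)
...     above = log_1mp.sum(dim=-1, keepdim=True) - torch.cumsum(log_1mp, dim=-1)
...     return below + above
>>> probs = torch.distributions.Categorical(logits=ordinal_logits_sketch(torch.randn(5))).probs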
Note
This class is mostly useful when you want to learn a distribution over a finite set which is obtained by discretising a continuous set.
- Parameters:
scores (torch.Tensor) – a tensor of shape […, N] where N is the size of the set supporting the distribution. Typically, the output of a neural network parametrising the distribution.
Examples
>>> import torch
>>> from torchrl.modules import Ordinal
>>>
>>> num_atoms, num_samples = 5, 20
>>> mean = (num_atoms - 1) / 2  # Target mean for samples, centered around the middle atom
>>> torch.manual_seed(42)
>>> logits = torch.ones((num_atoms,), requires_grad=True)
>>> optimizer = torch.optim.Adam([logits], lr=0.1)
>>>
>>> # Perform optimisation loop to minimise deviation from `mean`
>>> for _ in range(20):
...     sampler = Ordinal(scores=logits)
...     samples = sampler.sample((num_samples,))
...     # Define loss to encourage samples around the mean by penalising deviation from mean
...     loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
...     loss.backward()
...     optimizer.step()
...     optimizer.zero_grad()
>>>
>>> sampler.probs
tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
>>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
torch.return_types.histogram(
    hist=tensor([ 24., 158., 478., 228., 112.]),
    bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
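As the note above suggests, the atoms typically stand in for points of a discretised continuous set. The following is a small, hypothetical follow-up (the ``support`` grid and variable names are illustrative, not part of the class) showing how sampled atom indices can be mapped back to the continuous values they represent:
>>> import torch
>>> from torchrl.modules import Ordinal
>>>
>>> num_atoms = 11
>>> support = torch.linspace(-1.0, 1.0, num_atoms)  # continuous values behind each atom
>>> scores = torch.randn(num_atoms)  # e.g. the output of a policy head
>>> dist = Ordinal(scores=scores)
>>> indices = dist.sample((4,))  # discrete atom indices in [0, num_atoms - 1]
>>> actions = support[indices]  # map indices back to the continuous set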