MaskedOneHotCategorical

class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = -inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)

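Masked one-hot categorical distribution: a categorical distribution whose samples are one-hot encoded and whose masked categories receive zero probability.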

Reference: https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

Parameters:
  • logits (torch.Tensor) – event log probabilities (unnormalized)

  • probs (torch.Tensor) – event probabilities. If provided, the probabilities corresponding to masked items will be zeroed and the probability re-normalized along its last dimension.
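The renormalization described for probs can be illustrated with plain PyTorch. This is a minimal sketch, not torchrl's own code; renormalize_probs is a hypothetical helper name, and the mask follows the convention of the mask keyword argument (False marks entries to drop).

```python
import torch

def renormalize_probs(probs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero the masked probabilities, then renormalize."""
    # `mask` is boolean; False marks the entries to drop.
    kept = probs * mask.to(probs.dtype)
    # Divide by the remaining mass so each distribution sums to 1 again.
    return kept / kept.sum(dim=-1, keepdim=True)

probs = torch.full((10,), 0.1)
mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
new_probs = renormalize_probs(probs, mask)
# The masked outcome gets 0; the nine others get 1/9, and log(1/9) is about -2.1972.
print(new_probs[0].item(), new_probs[1:].log())
```

This matches the arithmetic behind the probability-based example further down, where every unmasked outcome has log-probability log(1/9).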

Keyword Arguments:
  • mask (torch.Tensor) – A boolean mask of the same shape as logits/probs where False entries are the ones to be masked. Alternatively, if sparse_mask is True, it represents the list of valid indices in the distribution. Exclusive with indices.

  • indices (torch.Tensor) – A dense index tensor representing which actions must be taken into account. Exclusive with mask.

  • neg_inf (float, optional) – The log-probability value allocated to invalid (out-of-mask) indices. Defaults to -inf.

  • padding_value – The padding value in the mask tensor. When sparse_mask == True, the padding_value will be ignored.

  • grad_method (ReparamGradientStrategy, optional) – strategy used to gather reparameterized samples. ReparamGradientStrategy.PassThrough computes the sample gradients by using the softmax-valued log-probability as a proxy for the sample gradients. ReparamGradientStrategy.RelaxedOneHot uses torch.distributions.RelaxedOneHotCategorical to sample from the distribution. Defaults to ReparamGradientStrategy.PassThrough.
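The relationship between indices, mask and neg_inf can be sketched as follows. This is an illustration under stated assumptions, not torchrl's implementation: indices_to_mask and masked_log_probs are hypothetical helpers, and the sketch assumes that masked categories are excluded from normalization and then reported with the neg_inf value.

```python
import torch

def indices_to_mask(indices: torch.Tensor, num_categories: int) -> torch.Tensor:
    """Hypothetical helper: build a boolean mask from a dense index tensor."""
    mask = torch.zeros(num_categories, dtype=torch.bool)
    mask[indices] = True  # listed categories are valid, all others are masked
    return mask

def masked_log_probs(logits: torch.Tensor, mask: torch.Tensor,
                     neg_inf: float = float("-inf")) -> torch.Tensor:
    """Hypothetical helper: log-probabilities with masked entries set to neg_inf."""
    # Exclude masked categories from the normalization constant.
    log_p = torch.log_softmax(logits.masked_fill(~mask, float("-inf")), dim=-1)
    # Report `neg_inf` for every out-of-mask category.
    return log_p.masked_fill(~mask, neg_inf)

logits = torch.zeros(4)
mask = indices_to_mask(torch.tensor([0, 2, 3]), num_categories=4)
print(mask)                                           # category 1 is masked
print(masked_log_probs(logits, mask))                 # -inf at index 1
print(masked_log_probs(logits, mask, neg_inf=-1e10))  # finite penalty at index 1
```

A finite neg_inf such as -1e10 can be useful when downstream code cannot tolerate infinities.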

Examples

>>> import torch
>>> from torchrl.modules import MaskedOneHotCategorical
>>> torch.manual_seed(0)
>>> logits = torch.randn(4) / 100  # almost equal probabilities
>>> mask = torch.tensor([True, False, True, True])
>>> dist = MaskedOneHotCategorical(logits=logits, mask=mask)
>>> sample = dist.sample((10,))
>>> print(sample)  # index 1 is masked, so no sample selects it
tensor([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]])
>>> print(dist.log_prob(sample))
tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
        -1.1203, -1.1203])
>>> sample_non_valid = torch.zeros_like(sample)
>>> sample_non_valid[..., 1] = 1
>>> print(dist.log_prob(sample_non_valid))
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
>>> # with probabilities
>>> prob = torch.ones(10)
>>> prob = prob / prob.sum()
>>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
>>> dist = MaskedOneHotCategorical(probs=prob, mask=mask)
>>> s = torch.arange(10)
>>> s = torch.nn.functional.one_hot(s, 10)
>>> print(dist.log_prob(s))
tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
        -2.1972, -2.1972])

log_prob(value: Tensor) → Tensor

Returns the log of the probability density/mass function evaluated at value.

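Parameters:
  • value (Tensor) – one-hot encoded sample(s) to evaluate. Samples that select a masked category evaluate to neg_inf.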
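Evaluating a log-probability on one-hot values amounts to picking, for each sample, the log-probability of the selected category. The sketch below is an assumption about how this can be done in plain PyTorch, not torchrl's code; one_hot_log_prob is a hypothetical helper. It uses gather rather than a dot product because 0 * -inf is NaN.

```python
import torch
import torch.nn.functional as F

def one_hot_log_prob(log_probs: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: log-probability of each one-hot row in `value`."""
    # argmax recovers the category index encoded by each one-hot vector.
    idx = value.argmax(dim=-1, keepdim=True)
    # gather avoids (value * log_probs).sum(-1), where 0 * -inf yields NaN.
    return log_probs.expand(value.shape).gather(-1, idx).squeeze(-1)

# Category 2 is masked: its logit is -inf, so its log-probability is -inf.
logits = torch.tensor([0.0, 0.0, float("-inf"), 0.0])
log_probs = torch.log_softmax(logits, dim=-1)
value = F.one_hot(torch.tensor([0, 2, 3]), num_classes=4)
print(one_hot_log_prob(log_probs, value))  # finite, -inf, finite
```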

property mode: Tensor

Returns the mode of the distribution.
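For a masked categorical, a natural reading of the mode is the one-hot encoding of the most probable valid category. The sketch below assumes that reading and uses a hypothetical helper, masked_mode; it shows that a masked category is never returned even when it has the largest logit.

```python
import torch
import torch.nn.functional as F

def masked_mode(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-hot encoding of the most probable valid category."""
    # Masked categories get -inf so argmax can never pick them.
    masked = logits.masked_fill(~mask, float("-inf"))
    return F.one_hot(masked.argmax(dim=-1), num_classes=logits.shape[-1])

logits = torch.tensor([0.1, 5.0, 0.3, 0.2])  # index 1 has the largest logit...
mask = torch.tensor([True, False, True, True])  # ...but it is masked
print(masked_mode(logits, mask))  # tensor([0, 0, 1, 0])
```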

rsample(sample_shape: Optional[Union[Size, Sequence]] = None) → Tensor

Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.
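The PassThrough strategy returns hard one-hot samples while letting gradients flow through the softmax probabilities. Below is a generic straight-through estimator illustrating that idea. It is a sketch, not torchrl's implementation: torchrl describes its proxy as the softmax-valued log-probability, so the exact formula may differ, and pass_through_rsample is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def pass_through_rsample(logits: torch.Tensor, mask: torch.Tensor,
                         num_samples: int) -> torch.Tensor:
    """Hypothetical helper: straight-through one-hot samples from a masked categorical."""
    probs = torch.softmax(logits.masked_fill(~mask, float("-inf")), dim=-1)
    # Draw hard category indices; masked categories have probability 0.
    idx = torch.multinomial(probs.detach(), num_samples, replacement=True)
    hard = F.one_hot(idx, num_classes=logits.shape[-1]).to(probs.dtype)
    # Forward value equals `hard` exactly (the bracket is zero), while the
    # backward pass routes gradients through the softmax probabilities.
    return hard + (probs - probs.detach())

torch.manual_seed(0)
logits = torch.zeros(4, requires_grad=True)
mask = torch.tensor([True, False, True, True])
samples = pass_through_rsample(logits, mask, num_samples=5)
reward = torch.tensor([1.0, 0.0, 2.0, 4.0])  # arbitrary per-category score
(samples * reward).sum().backward()
print(samples.detach())  # one-hot rows, never index 1
print(logits.grad)       # gradient at the masked index 1 is zero
```

Keeping the bracket (probs - probs.detach()) grouped guarantees that the forward value is bit-for-bit the hard one-hot sample.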

sample(sample_shape: Optional[Union[Size, Sequence[int]]] = None) → Tensor

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
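How sample_shape combines with batched parameters can be sketched with torch.distributions.Categorical on masked logits, followed by a one-hot encoding. This is an illustrative equivalent under that assumption, not torchrl's code; masked_sample is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def masked_sample(logits: torch.Tensor, mask: torch.Tensor,
                  sample_shape=()) -> torch.Tensor:
    """Hypothetical helper: non-differentiable one-hot samples from a masked categorical."""
    # -inf logits give masked categories zero probability.
    dist = torch.distributions.Categorical(logits=logits.masked_fill(~mask, float("-inf")))
    # sample_shape is prepended to the distribution's batch shape.
    idx = dist.sample(torch.Size(sample_shape))
    return F.one_hot(idx, num_classes=logits.shape[-1])

torch.manual_seed(0)
logits = torch.randn(2, 4)  # batch of 2 distributions over 4 categories
mask = torch.tensor([[True, False, True, True],
                     [False, True, True, False]])
samples = masked_sample(logits, mask, sample_shape=(10,))
print(samples.shape)  # torch.Size([10, 2, 4])
```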
