MaskedOneHotCategorical
- class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = -inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)
MaskedCategorical distribution over one-hot encoded events.
Reference: https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical
- Parameters:
logits (torch.Tensor) – event log probabilities (unnormalized)
probs (torch.Tensor) – event probabilities. If provided, the probabilities corresponding to masked items will be zeroed and the probability re-normalized along its last dimension.
- Keyword Arguments:
mask (torch.Tensor) – A boolean mask of the same shape as logits/probs where False entries are the ones to be masked. Alternatively, if sparse_mask is True, it represents the list of valid indices in the distribution. Exclusive with indices.
indices (torch.Tensor) – A dense index tensor representing which actions must be taken into account. Exclusive with mask.
neg_inf (float, optional) – The log-probability value allocated to invalid (out-of-mask) indices. Defaults to -inf.
padding_value – The padding value in the mask tensor. When sparse_mask == True, the padding_value will be ignored.
grad_method (ReparamGradientStrategy, optional) – Strategy to gather reparameterized samples. ReparamGradientStrategy.PassThrough will compute the sample gradients by using the softmax-valued log-probability as a proxy for the sample gradients. ReparamGradientStrategy.RelaxedOneHot will use torch.distributions.RelaxedOneHotCategorical to sample from the distribution.

Examples:
>>> import torch
>>> from torchrl.modules import MaskedOneHotCategorical
>>> torch.manual_seed(0)
>>> logits = torch.randn(4) / 100  # almost equal probabilities
>>> mask = torch.tensor([True, False, True, True])
>>> dist = MaskedOneHotCategorical(logits=logits, mask=mask)
>>> sample = dist.sample((10,))
>>> print(sample)  # index 1 is never sampled
tensor([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]])
>>> print(dist.log_prob(sample))
tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
        -1.1203, -1.1203])
>>> sample_non_valid = torch.zeros_like(sample)
>>> sample_non_valid[..., 1] = 1
>>> print(dist.log_prob(sample_non_valid))
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
>>> # with probabilities
>>> prob = torch.ones(10)
>>> prob = prob / prob.sum()
>>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
>>> dist = MaskedOneHotCategorical(probs=prob, mask=mask)
>>> s = torch.arange(10)
>>> s = torch.nn.functional.one_hot(s, 10)
>>> print(dist.log_prob(s))
tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
        -2.1972, -2.1972])
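The ReparamGradientStrategy.PassThrough option described above corresponds to what is commonly called a straight-through estimator. The following is a minimal sketch of that general technique in plain PyTorch, not TorchRL's actual implementation: the helper name `pass_through_sample` is hypothetical, and masking is assumed to work by filling invalid logits with -inf before the softmax.

```python
import torch


def pass_through_sample(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-hot sample whose gradient follows the softmax."""
    # Invalid entries get -inf logits, hence exactly zero probability.
    masked_logits = logits.masked_fill(~mask, float("-inf"))
    probs = torch.softmax(masked_logits, dim=-1)
    # Draw a hard, non-differentiable one-hot sample from the masked probabilities.
    index = torch.multinomial(probs.detach(), num_samples=1)
    hard = torch.zeros_like(probs).scatter_(-1, index, 1.0)
    # Forward value is (numerically) `hard`; the backward pass uses d(probs)/d(logits).
    return hard + probs - probs.detach()


torch.manual_seed(0)
logits = torch.zeros(4, requires_grad=True)
mask = torch.tensor([True, False, True, True])

sample = pass_through_sample(logits, mask)
weights = torch.tensor([1.0, 2.0, 3.0, 4.0])  # arbitrary downstream loss weights
(sample * weights).sum().backward()
```

After the backward pass, `logits.grad` is populated even though the sample itself is discrete, and the entry for the masked index stays at zero because that logit never contributes to the softmax.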
- log_prob(value: Tensor) → Tensor
Returns the log of the probability density/mass function evaluated at value.
- Parameters:
value (Tensor) – the one-hot encoded event(s) at which to evaluate the log-probability.
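To make the arithmetic behind log_prob concrete, here is a small sketch using only the Python standard library. It assumes the behaviour described for probs (masked probabilities are zeroed and the rest renormalized) and that the log-probability of a one-hot event is the entry selected by its single 1. The helper names `masked_log_probs` and `one_hot_log_prob` are hypothetical and are not part of TorchRL.

```python
import math


def masked_log_probs(probs, mask):
    """Hypothetical helper: zero masked probabilities, renormalize, take logs."""
    kept = [p if m else 0.0 for p, m in zip(probs, mask)]
    total = sum(kept)
    return [math.log(p / total) if p > 0 else -math.inf for p in kept]


def one_hot_log_prob(log_probs, one_hot):
    """Log-probability of a one-hot event: the entry selected by its single 1."""
    return log_probs[one_hot.index(1)]


probs = [0.1] * 10             # uniform over 10 outcomes
mask = [False] + [True] * 9    # first outcome is masked out
log_probs = masked_log_probs(probs, mask)

valid = [0, 1] + [0] * 8       # one-hot for outcome 1
invalid = [1] + [0] * 9        # one-hot for the masked outcome 0
print(round(one_hot_log_prob(log_probs, valid), 4))  # -2.1972
print(one_hot_log_prob(log_probs, invalid))          # -inf
```

With nine valid outcomes of equal weight, each has probability 1/9, and log(1/9) ≈ -2.1972, while the masked outcome has log-probability -inf.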