class LogitRelaxedBernoulli(Distribution):
    r"""
    Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli
    distribution.

    Samples are logits of values in (0, 1). See [1] for more details.

    Args:
        temperature (Tensor): relaxation temperature
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
        validate_args (bool, optional): passed through unchanged to the
            parent ``Distribution`` initializer

    Raises:
        ValueError: if neither or both of :attr:`probs` and :attr:`logits`
            are given

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
    Variables (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """

    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.real

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        self.temperature = temperature

        # Exactly one parameterization must be supplied: the equality is true
        # both when neither argument is given and when both are.
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )

        # Record whether the caller passed a plain Python number *before* the
        # value goes through broadcast_all, since that decides the batch shape.
        # NOTE(review): only one of self.probs / self.logits is assigned here;
        # the counterpart is presumably derived elsewhere in the class, which
        # is not visible in this view -- confirm.
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)

        # _param aliases whichever parameter the caller actually provided.
        self._param = self.probs if probs is not None else self.logits

        # A numeric scalar yields an empty batch shape; otherwise the batch
        # shape is the size of the supplied parameter.
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)
[docs]classRelaxedBernoulli(TransformedDistribution):r""" Creates a RelaxedBernoulli distribution, parametrized by :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both). This is a relaxed version of the `Bernoulli` distribution, so the values are in (0, 1), and has reparametrizable samples. Example:: >>> # xdoctest: +IGNORE_WANT("non-deterinistic") >>> m = RelaxedBernoulli(torch.tensor([2.2]), ... torch.tensor([0.1, 0.2, 0.3, 0.99])) >>> m.sample() tensor([ 0.2951, 0.3442, 0.8918, 0.9021]) Args: temperature (Tensor): relaxation temperature probs (Number, Tensor): the probability of sampling `1` logits (Number, Tensor): the log-odds of sampling `1` """arg_constraints={"probs":constraints.unit_interval,"logits":constraints.real}support=constraints.unit_intervalhas_rsample=Truedef__init__(self,temperature,probs=None,logits=None,validate_args=None):base_dist=LogitRelaxedBernoulli(temperature,probs,logits)super().__init__(base_dist,SigmoidTransform(),validate_args=validate_args)
# To analyze traffic and optimize your experience, we serve cookies on this
# site. By clicking or navigating, you agree to allow our usage of cookies. As
# the current maintainers of this site, Facebook’s Cookies Policy applies.
# Learn more, including about available controls: Cookies Policy.