Source code for torch.distributions.constraint_registry
r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they make different guarantees
about bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    import torch
    from torch.distributions import Normal, transform_to

    data = torch.randn(100)  # observations to fit
    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    import torch
    from torch.distributions import Exponential, biject_to

    rate = torch.ones(100)
    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that maps
    an unconstrained input with one fewer dimension bijectively onto the
    simplex; this is a more expensive, less numerically stable transform, but
    it is needed for algorithms like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method, either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_factory)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""

import numbers

from torch.distributions import constraints, transforms


__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]
class ConstraintRegistry:
    """
    Registry to link constraints to transforms.
    """

    def __init__(self):
        self._registry = {}
        super().__init__()
    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        self._registry[constraint] = factory
        return factory
    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            `NotImplementedError` if no transform has been registered.
        """
        # Look up by Constraint subclass.
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)
biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()


################################################################################
# Registration Table
################################################################################


@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    base_transform = biject_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    return transforms.ExpTransform()


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.lower_bound, 1),
        ]
    )


@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.upper_bound, -1),
        ]
    )


@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    # Handle the special case of the unit interval.
    lower_is_0 = (
        isinstance(constraint.lower_bound, numbers.Number)
        and constraint.lower_bound == 0
    )
    upper_is_1 = (
        isinstance(constraint.upper_bound, numbers.Number)
        and constraint.upper_bound == 1
    )
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
    )


@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    return transforms.StickBreakingTransform()


@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    return transforms.SoftmaxTransform()


# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    return transforms.LowerCholeskyTransform()


@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    return transforms.PositiveDefiniteTransform()


@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    return transforms.CorrCholeskyTransform()


@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    return transforms.CatTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )


@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    return transforms.CatTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )


@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    return transforms.StackTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim
    )


@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    return transforms.StackTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim
    )
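The two global registries, and a private one built the way the module docstring suggests, can be exercised with a short script. This is a sketch assuming a recent PyTorch release in which the registrations above are in effect; `ShiftedPositive`, `_shifted_positive_factory` and `my_registry` are hypothetical names introduced only for illustration and are not part of `torch.distributions`.

```python
import torch
from torch.distributions import constraints, transforms
from torch.distributions.constraint_registry import (
    ConstraintRegistry,
    biject_to,
    transform_to,
)

# 1. Unconstrained -> constrained and back again with transform_to.
positive_t = transform_to(constraints.positive)
x = torch.tensor([-1.0, 0.0, 2.0])   # arbitrary unconstrained values
y = positive_t(x)                    # strictly positive values
x_roundtrip = positive_t.inv(y)      # recovers x up to float error

# 2. A general interval: sigmoid, then an affine map onto (-1, 3).
interval_t = transform_to(constraints.interval(-1.0, 3.0))
mid = interval_t(torch.zeros(1))     # -1 + 4 * sigmoid(0) = 1

# 3. The simplex: softmax keeps the input size, stick-breaking adds one coordinate.
soft = transform_to(constraints.simplex)(torch.zeros(4))  # 4 inputs -> 4 outputs
stick = biject_to(constraints.simplex)(torch.zeros(3))    # 3 inputs -> 4 outputs


# 4. A private registry holding a HYPOTHETICAL user-defined constraint.
class ShiftedPositive(constraints.Constraint):
    """Hypothetical constraint (illustration only): values greater than ``lower``."""

    def __init__(self, lower):
        self.lower = lower
        super().__init__()

    def check(self, value):
        return value > self.lower


my_registry = ConstraintRegistry()  # hypothetical, separate from the globals


@my_registry.register(ShiftedPositive)
def _shifted_positive_factory(constraint):
    # exp maps R onto (0, inf); the affine shift moves that onto (lower, inf).
    return transforms.ComposeTransform(
        [transforms.ExpTransform(), transforms.AffineTransform(constraint.lower, 1.0)]
    )


shifted = my_registry(ShiftedPositive(2.0))(torch.zeros(3))  # exp(0) + 2 = 3

# The private registry knows nothing about the built-in constraints.
try:
    my_registry(constraints.real)
    lookup_failed = False
except NotImplementedError:
    lookup_failed = True
```

The values checked here follow from the registration table: `AffineTransform(loc, scale)` computes `loc + scale * x`, so the `greater_than` entry maps `x` to `lower_bound + exp(x)`, the `less_than` entry to `upper_bound - exp(x)`, and the general `interval` entry to `lower_bound + (upper_bound - lower_bound) * sigmoid(x)`. At `x = 0` the stick-breaking transform lands on the uniform point of the simplex, the same point softmax produces from zeros, even though it consumes one fewer input coordinate.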