class Constraint:
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.

    Attributes:
        is_discrete (bool): Whether constrained space is discrete.
            Defaults to False.
        event_dim (int): Number of rightmost dimensions that together define
            an event. The :meth:`check` method will remove this many dimensions
            when computing validity.
    """

    is_discrete = False  # Default to continuous.
    event_dim = 0  # Default to univariate.

    def check(self, value):
        """
        Returns a boolean tensor of ``sample_shape + batch_shape`` indicating
        whether each event in value satisfies this constraint.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def __repr__(self):
        # Concrete constraints are private classes named ``_Name``; strip the
        # leading underscore so e.g. an instance of ``_Real`` reprs as
        # ``Real()`` instead of the default ``<... object at 0x...>``.
        return self.__class__.__name__[1:] + "()"
class _Dependent(Constraint):
    """
    Placeholder for variables whose support depends on other variables.
    These variables obey no simple coordinate-wise constraints.

    Args:
        is_discrete (bool): Optional value of ``.is_discrete`` in case this
            can be computed statically. If not provided, access to the
            ``.is_discrete`` attribute will raise a NotImplementedError.
        event_dim (int): Optional value of ``.event_dim`` in case this
            can be computed statically. If not provided, access to the
            ``.event_dim`` attribute will raise a NotImplementedError.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        # ``NotImplemented`` is used as the "unknown" sentinel so that any real
        # value (including False / 0) can be supplied explicitly.
        self._is_discrete = is_discrete
        self._event_dim = event_dim
        super().__init__()

    @property
    def is_discrete(self):
        if self._is_discrete is NotImplemented:
            raise NotImplementedError(".is_discrete cannot be determined statically")
        return self._is_discrete

    @property
    def event_dim(self):
        if self._event_dim is NotImplemented:
            raise NotImplementedError(".event_dim cannot be determined statically")
        return self._event_dim

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Support for syntax to customize static attributes::

            constraints.dependent(is_discrete=True, event_dim=1)

        Attributes not passed are inherited from this instance.
        """
        if is_discrete is NotImplemented:
            is_discrete = self._is_discrete
        if event_dim is NotImplemented:
            event_dim = self._event_dim
        return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

    def check(self, x):
        """
        Raises:
            ValueError: always; a dependent support cannot be checked.
        """
        raise ValueError("Cannot determine validity of dependent constraint")


def is_dependent(constraint):
    """
    Return ``True`` if ``constraint`` is a ``_Dependent`` placeholder
    (this includes ``dependent_property`` objects).
    """
    return isinstance(constraint, _Dependent)


class _DependentProperty(property, _Dependent):
    """
    Decorator that extends @property to act like a `Dependent` constraint when
    called on a class and act like a property when called on an object.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The function to be decorated.
        is_discrete (bool): Optional value of ``.is_discrete`` in case this
            can be computed statically. If not provided, access to the
            ``.is_discrete`` attribute will raise a NotImplementedError.
        event_dim (int): Optional value of ``.event_dim`` in case this
            can be computed statically. If not provided, access to the
            ``.event_dim`` attribute will raise a NotImplementedError.
    """

    def __init__(
        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
    ):
        # The MRO resolves ``super()`` to ``property``; the static attributes
        # are stored directly afterwards.
        super().__init__(fn)
        self._is_discrete = is_discrete
        self._event_dim = event_dim

    def __call__(self, fn):
        """
        Support for syntax to customize static attributes::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self): ...
        """
        return _DependentProperty(
            fn, is_discrete=self._is_discrete, event_dim=self._event_dim
        )


class _IndependentConstraint(Constraint):
    """
    Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
    dims in :meth:`check`, so that an event is valid only if all its
    independent entries are valid.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def is_discrete(self):
        return self.base_constraint.is_discrete

    @property
    def event_dim(self):
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def check(self, value):
        result = self.base_constraint.check(value)
        if result.dim() < self.reinterpreted_batch_ndims:
            expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
            raise ValueError(
                f"Expected value.dim() >= {expected} but got {value.dim()}"
            )
        # Flatten the reinterpreted dims into one trailing dim and require all
        # of them to be valid. With 0 reinterpreted dims this appends a size-1
        # dim that ``.all(-1)`` removes again.
        result = result.reshape(
            result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
        )
        result = result.all(-1)
        return result

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"


class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    is_discrete = True

    def check(self, value):
        return (value == 0) | (value == 1)


class _OneHot(Constraint):
    """
    Constrain to one-hot vectors.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        is_boolean = (value == 0) | (value == 1)
        is_normalized = value.sum(-1).eq(1)
        return is_boolean.all(-1) & is_normalized


class _IntegerInterval(Constraint):
    """
    Constrain to an integer interval `[lower_bound, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (
            (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
        )

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _IntegerLessThan(Constraint):
    """
    Constrain to an integer interval `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (value % 1 == 0) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string


class _IntegerGreaterThan(Constraint):
    """
    Constrain to an integer interval `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return (value % 1 == 0) & (value >= self.lower_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    """

    def check(self, value):
        return value == value  # False for NANs.


class _GreaterThan(Constraint):
    """
    Constrain to a real half line `(lower_bound, inf]`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound < value

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _GreaterThanEq(Constraint):
    """
    Constrain to a real half line `[lower_bound, inf)`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound <= value

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _LessThan(Constraint):
    """
    Constrain to a real half line `[-inf, upper_bound)`.
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return value < self.upper_bound

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string


class _Interval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (self.lower_bound <= value) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _HalfOpenInterval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound)`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (self.lower_bound <= value) & (value < self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension.
    Specifically: `x >= 0` and `x.sum(-1) == 1` (up to an absolute
    tolerance of 1e-6 on the sum).
    """

    event_dim = 1

    def check(self, value):
        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)


class _Multinomial(Constraint):
    """
    Constrain to nonnegative values summing to at most an upper bound.

    Note due to limitations of the Multinomial distribution, this currently
    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
    this may be strengthened to ``value.sum(-1) == upper_bound``. Integrality
    of the entries is not checked.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, x):
        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string


def _is_lower_triangular(value):
    """
    For each matrix in the last two dims of ``value``, return whether it equals
    its lower triangle (shape ``value.shape[:-2]``).
    """
    value_tril = value.tril()
    # ``min`` over the flattened matrix acts as a logical "all".
    return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]


class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.
    """

    event_dim = 2

    def check(self, value):
        return _is_lower_triangular(value)


class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.
    """

    event_dim = 2

    def check(self, value):
        lower_triangular = _is_lower_triangular(value)
        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
        return lower_triangular & positive_diagonal


class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals and each
    row vector being of unit length.
    """

    event_dim = 2

    def check(self, value):
        tol = (
            torch.finfo(value.dtype).eps * value.size(-1) * 10
        )  # 10 is an adjustable fudge factor
        row_norm = torch.linalg.norm(value.detach(), dim=-1)
        unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1)
        return _LowerCholesky().check(value) & unit_row_norm


class _Square(Constraint):
    """
    Constrain to square matrices.
    """

    event_dim = 2

    def check(self, value):
        # Squareness depends only on the shape, so every batch entry shares
        # the same answer.
        return torch.full(
            size=value.shape[:-2],
            fill_value=(value.shape[-2] == value.shape[-1]),
            dtype=torch.bool,
            device=value.device,
        )


class _Symmetric(_Square):
    """
    Constrain to Symmetric square matrices.
    """

    def check(self, value):
        square_check = super().check(value)
        # Comparing with the transpose is only meaningful for square input.
        if not square_check.all():
            return square_check
        return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)


class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.all():
            return sym_check
        return torch.linalg.eigvalsh(value).ge(0).all(-1)


class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.all():
            return sym_check
        # ``info == 0`` means the Cholesky factorization succeeded.
        return torch.linalg.cholesky_ex(value).info.eq(0)


class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)


class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.stack`.
    """

    def __init__(self, cseq, dim=0):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        dim = max(c.event_dim for c in self.cseq)
        # A negative stacking dim lying to the left of the base event dims is
        # counted as one additional event dim.
        if self.dim + dim < 0:
            dim += 1
        return dim

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        return torch.stack(
            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
        )


# Public interface.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
real_vector = independent(real, 1)
positive = _GreaterThan(0.0)
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
multinomial = _Multinomial
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.