# mypy: allow-untyped-defs
from functools import update_wrapper
from numbers import Number
from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch.overrides import is_tensor_like


euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant

__all__ = [
    "broadcast_all",
    "logits_to_probs",
    "clamp_probs",
    "probs_to_logits",
    "lazy_property",
    "tril_matrix_to_vec",
    "vec_to_tril_matrix",
]


def broadcast_all(*values):
    r"""
    Given a list of values (possibly containing numbers), returns a list where each
    value is broadcasted based on the following rules:

    - `torch.*Tensor` instances are broadcasted as per :ref:`broadcasting-semantics`.
    - `numbers.Number` instances (scalars) are first converted to tensors with the
      same dtype and device as the first `torch.Tensor` found in `values` (or with
      the default dtype, ``torch.get_default_dtype()``, if there is none), and are
      then broadcast together with the remaining values. If all the values are
      scalars, the result consists of 0-dimensional tensors.

    Args:
        values (list of `numbers.Number`, `torch.*Tensor` or objects implementing __torch_function__)

    Returns:
        The broadcast tensors, as returned by :func:`torch.broadcast_tensors`.

    Raises:
        ValueError: if any of the values is not a `numbers.Number` instance,
            a `torch.*Tensor` instance, or an instance implementing __torch_function__
    """
    if not all(is_tensor_like(v) or isinstance(v, Number) for v in values):
        raise ValueError(
            "Input arguments must all be instances of numbers.Number, "
            "torch.Tensor or objects implementing __torch_function__."
        )
    if not all(is_tensor_like(v) for v in values):
        # Scalars inherit dtype/device from the first real torch.Tensor; tensor-like
        # objects that are not torch.Tensor instances are skipped for this purpose.
        options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
        for value in values:
            if isinstance(value, torch.Tensor):
                options = dict(dtype=value.dtype, device=value.device)
                break
        new_values = [
            v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
        ]
        return torch.broadcast_tensors(*new_values)
    return torch.broadcast_tensors(*values)


def _standard_normal(shape, dtype, device):
    r"""
    Returns a tensor of the given ``shape``, ``dtype`` and ``device`` filled with
    samples from the standard normal distribution N(0, 1).
    """
    if torch._C._get_tracing_state():
        # [JIT WORKAROUND] lack of support for .normal_()
        return torch.normal(
            torch.zeros(shape, dtype=dtype, device=device),
            torch.ones(shape, dtype=dtype, device=device),
        )
    return torch.empty(shape, dtype=dtype, device=device).normal_()


def _sum_rightmost(value, dim):
    r"""
    Sum out ``dim`` many rightmost dimensions of a given tensor.

    Args:
        value (Tensor): A tensor of ``.dim()`` at least ``dim``.
        dim (int): The number of rightmost dims to sum out.
    """
    if dim == 0:
        return value
    # Collapse the trailing ``dim`` dimensions into a single one and reduce over it.
    required_shape = value.shape[:-dim] + (-1,)
    return value.reshape(required_shape).sum(-1)


def logits_to_probs(logits, is_binary=False):
    r"""
    Converts a tensor of logits into probabilities. Note that for the binary case,
    each value denotes log odds, whereas for the multi-dimensional case, the values
    along the last dimension denote the log probabilities (possibly unnormalized)
    of the events.

    Binary inputs are mapped through a sigmoid; otherwise a softmax is taken over
    the last dimension.
    """
    if is_binary:
        return torch.sigmoid(logits)
    return F.softmax(logits, dim=-1)


def clamp_probs(probs):
    """Clamps the probabilities to be in the open interval `(0, 1)`.

    The probabilities would be clamped between `eps` and `1 - eps`,
    and `eps` would be the smallest representable positive number for the input data type.

    Args:
        probs (Tensor): A tensor of probabilities.

    Returns:
        Tensor: The clamped probabilities.

    Examples:
        >>> probs = torch.tensor([0.0, 0.5, 1.0])
        >>> clamp_probs(probs)
        tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])

        >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
        >>> clamp_probs(probs)
        tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)

    """
    eps = torch.finfo(probs.dtype).eps
    return probs.clamp(min=eps, max=1 - eps)


def probs_to_logits(probs, is_binary=False):
    r"""
    Converts a tensor of probabilities into logits. For the binary case,
    this denotes the probability of occurrence of the event indexed by `1`.
    For the multi-dimensional case, the values along the last dimension
    denote the probabilities of occurrence of each of the events.

    The probabilities are first passed through :func:`clamp_probs` so that the
    logarithms below never see exactly 0 or 1. The binary case returns the log
    odds ``log(p) - log(1 - p)``; otherwise ``log(p)`` is returned.
    """
    ps_clamped = clamp_probs(probs)
    if is_binary:
        return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
    return torch.log(ps_clamped)


class lazy_property:
    r"""
    Used as a decorator for lazy loading of class attributes. This uses a
    non-data descriptor that calls the wrapped method to compute the property on
    first access (inside ``torch.enable_grad()``) and then stores the result as an
    instance attribute with the same name as the wrapped method. Because the
    descriptor only defines ``__get__``, that instance attribute shadows it on
    later lookups, so the wrapped method runs at most once per instance.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        update_wrapper(self, wrapped)  # type:ignore[arg-type]

    def __get__(self, instance, obj_type=None):
        if instance is None:
            # Class-level access (e.g. from Sphinx autodoc): hand back an object that
            # is both a lazy_property and a property.
            return _lazy_property_and_property(self.wrapped)
        with torch.enable_grad():
            value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)
        return value


class _lazy_property_and_property(lazy_property, property):
    """We want lazy properties to look like multiple things.

    * property when Sphinx autodoc looks
    * lazy_property when Distribution validate_args looks
    """

    def __init__(self, wrapped):
        property.__init__(self, wrapped)


def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
    r"""
    Convert a `D x D` matrix or a batch of matrices into a (batched) vector which
    comprises the lower triangular elements from the matrix in row order.

    Args:
        mat (Tensor): Matrix (or batch of matrices); only the last dimension's size
            ``D`` is read here, and the last two dimensions are indexed by the mask.
        diag (int): Diagonal offset. An element ``(i, j)`` is kept when
            ``j - i <= diag``: ``0`` includes the main diagonal, positive values also
            keep super-diagonals, negative values drop sub-diagonals.

    Returns:
        Tensor: The selected elements, in row-major order, along a single last dim.

    Raises:
        ValueError: if ``diag`` is outside ``[-D, D - 1]`` (checked only when not
            JIT tracing).
    """
    n = mat.shape[-1]
    if not torch._C._get_tracing_state() and (diag < -n or diag >= n):
        raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].")
    arange = torch.arange(n, device=mat.device)
    # Boolean (D, D) mask: column index < row index + diag + 1.
    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
    vec = mat[..., tril_mask]
    return vec


def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
    r"""
    Convert a vector or a batch of vectors into a batched `D x D` lower triangular
    matrix containing elements from the vector in row order.

    ``D`` is recovered from the size of the last dimension of ``vec`` and ``diag``
    by solving the quadratic noted in the body; this is the inverse of
    :func:`tril_matrix_to_vec` for the same ``diag``. Entries outside the mask
    are zero.

    Raises:
        ValueError: if the size of the last dimension does not correspond to any
            integer ``D`` (checked only when not JIT tracing).
    """
    # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
    n = (
        -(1 + 2 * diag)
        + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
    ) / 2
    eps = torch.finfo(vec.dtype).eps
    # The root may land on either side of the nearest integer, so compare the
    # absolute distance. Without abs(), sizes whose root rounds *down* (e.g. 4
    # elements with diag=0, n ~ 2.37) skipped this check and failed later with an
    # opaque shape-mismatch error.
    if not torch._C._get_tracing_state() and (abs(round(n) - n) > eps):
        raise ValueError(
            f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
            + "the lower triangular part of a square D x D matrix."
        )
    n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
    mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
    arange = torch.arange(n, device=vec.device)
    # Same mask as tril_matrix_to_vec, so the two functions round-trip.
    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
    mat[..., tril_mask] = vec
    return mat
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.