class Distribution:
    r"""
    Distribution is the abstract base class for probability distributions.
    """

    # Subclasses that implement a reparameterized sampler (``rsample``)
    # set this to True, enabling pathwise gradient estimators.
    has_rsample = False
    # Discrete subclasses that implement ``enumerate_support`` set this to True.
    has_enumerate_support = False
    # Class-wide default for argument validation. Mirrors Python's ``assert``
    # semantics: on by default, off when running under ``python -O``.
    _validate_args = __debug__
[docs]@staticmethoddefset_default_validate_args(value:bool)->None:""" Sets whether validation is enabled or disabled. The default behavior mimics Python's ``assert`` statement: validation is on by default, but is disabled if Python is run in optimized mode (via ``python -O``). Validation may be expensive, so you may want to disable it once a model is working. Args: value (bool): Whether to enable validation. """ifvaluenotin[True,False]:raiseValueErrorDistribution._validate_args=value
def __init__(
    self,
    batch_shape: torch.Size = torch.Size(),
    event_shape: torch.Size = torch.Size(),
    validate_args: Optional[bool] = None,
):
    """
    Base-class initializer: records the batch/event shapes and, when
    validation is enabled, checks every declared parameter against its
    constraint in ``arg_constraints``.

    Args:
        batch_shape (torch.Size): shape over which parameters are batched.
        event_shape (torch.Size): shape of a single sample (without batching).
        validate_args (bool, optional): when not None, overrides the
            class-level ``_validate_args`` default for this instance.
    """
    self._batch_shape = batch_shape
    self._event_shape = event_shape
    if validate_args is not None:
        # Instance attribute shadows the class-level default.
        self._validate_args = validate_args
    if self._validate_args:
        try:
            arg_constraints = self.arg_constraints
        except NotImplementedError:
            # Missing `arg_constraints` is tolerated (warn, don't fail) so
            # that bare subclasses still construct.
            arg_constraints = {}
            warnings.warn(
                f"{self.__class__} does not define `arg_constraints`. "
                + "Please set `arg_constraints = {}` or initialize the distribution "
                + "with `validate_args=False` to turn off validation."
            )
        for param, constraint in arg_constraints.items():
            if constraints.is_dependent(constraint):
                continue  # skip constraints that cannot be checked
            if param not in self.__dict__ and isinstance(
                getattr(type(self), param), lazy_property
            ):
                continue  # skip checking lazily-constructed args
            value = getattr(self, param)
            valid = constraint.check(value)
            if not valid.all():
                raise ValueError(
                    f"Expected parameter {param} "
                    f"({type(value).__name__} of shape {tuple(value.shape)}) "
                    f"of distribution {repr(self)} "
                    f"to satisfy the constraint {repr(constraint)}, "
                    f"but found invalid values:\n{value}"
                )
    super().__init__()
def expand(self, batch_shape: torch.Size, _instance=None):
    """
    Produces a new distribution instance (or fills in ``_instance`` when a
    derived class supplies one) whose batch dimensions are expanded to
    ``batch_shape``.

    Expansion calls :class:`~torch.Tensor.expand` on the distribution's
    parameters, so no new memory is allocated for the expanded instance.
    Argument checking and parameter broadcasting performed in
    ``__init__.py`` at construction time are likewise not repeated.

    Args:
        batch_shape (torch.Size): the desired expanded size.
        _instance: new instance provided by subclasses that
            need to override `.expand`.

    Returns:
        New distribution instance with batch dimensions expanded to
        `batch_size`.
    """
    raise NotImplementedError
@propertydefbatch_shape(self)->torch.Size:""" Returns the shape over which parameters are batched. """returnself._batch_shape@propertydefevent_shape(self)->torch.Size:""" Returns the shape of a single sample (without batching). """returnself._event_shape@propertydefarg_constraints(self)->Dict[str,constraints.Constraint]:""" Returns a dictionary from argument names to :class:`~torch.distributions.constraints.Constraint` objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict. """raiseNotImplementedError@propertydefsupport(self)->Optional[Any]:""" Returns a :class:`~torch.distributions.constraints.Constraint` object representing this distribution's support. """raiseNotImplementedError@propertydefmean(self)->torch.Tensor:""" Returns the mean of the distribution. """raiseNotImplementedError@propertydefmode(self)->torch.Tensor:""" Returns the mode of the distribution. """raiseNotImplementedError(f"{self.__class__} does not implement mode")@propertydefvariance(self)->torch.Tensor:""" Returns the variance of the distribution. """raiseNotImplementedError@propertydefstddev(self)->torch.Tensor:""" Returns the standard deviation of the distribution. """returnself.variance.sqrt()
[docs]defsample(self,sample_shape:torch.Size=torch.Size())->torch.Tensor:""" Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. """withtorch.no_grad():returnself.rsample(sample_shape)
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    """
    Draws a ``sample_shape``-shaped reparameterized sample (or a
    ``sample_shape``-shaped batch of reparameterized samples when the
    distribution's parameters are batched).
    """
    raise NotImplementedError
[docs]defsample_n(self,n:int)->torch.Tensor:""" Generates n samples or n batches of samples if the distribution parameters are batched. """warnings.warn("sample_n will be deprecated. Use .sample((n,)) instead",UserWarning)returnself.sample(torch.Size((n,)))
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """
    Evaluates the log of the probability density/mass function at ``value``.

    Args:
        value (Tensor):
    """
    raise NotImplementedError
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """
    Evaluates the cumulative density/mass function at ``value``.

    Args:
        value (Tensor):
    """
    raise NotImplementedError
def icdf(self, value: torch.Tensor) -> torch.Tensor:
    """
    Evaluates the inverse cumulative density/mass function at ``value``.

    Args:
        value (Tensor):
    """
    raise NotImplementedError
def enumerate_support(self, expand: bool = True) -> torch.Tensor:
    """
    Returns a tensor enumerating every value supported by a discrete
    distribution. Values are laid out along dimension 0, so the result has
    shape ``(cardinality,) + batch_shape + event_shape`` (where
    ``event_shape = ()`` for univariate distributions).

    Note that batched tensors are enumerated in lock-step,
    `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
    along dim 0, but with the remaining batch dimensions being
    singleton dimensions, `[[0], [1], ..`.

    To iterate over the full Cartesian product use
    `itertools.product(m.enumerate_support())`.

    Args:
        expand (bool): whether to expand the support over the batch
            dims to match the distribution's `batch_shape`.

    Returns:
        Tensor iterating over dimension 0.
    """
    raise NotImplementedError
def entropy(self) -> torch.Tensor:
    """
    Computes the entropy of the distribution, batched over ``batch_shape``.

    Returns:
        Tensor of shape batch_shape.
    """
    raise NotImplementedError
[docs]defperplexity(self)->torch.Tensor:""" Returns perplexity of distribution, batched over batch_shape. Returns: Tensor of shape batch_shape. """returntorch.exp(self.entropy())
def _extended_shape(self, sample_shape: _size = torch.Size()) -> Tuple[int, ...]:
    """
    Returns the size of the sample returned by the distribution, given
    a `sample_shape`. Note, that the batch and event shapes of a distribution
    instance are fixed at the time of construction. If this is empty, the
    returned shape is upcast to (1,).

    Args:
        sample_shape (torch.Size): the size of the sample to be drawn.
    """
    if not isinstance(sample_shape, torch.Size):
        sample_shape = torch.Size(sample_shape)
    # Full sample shape = sample dims, then batch dims, then event dims.
    return torch.Size(sample_shape + self._batch_shape + self._event_shape)

def _validate_sample(self, value: torch.Tensor) -> None:
    """
    Argument validation for distribution methods such as `log_prob`,
    `cdf` and `icdf`. The rightmost dimensions of a value to be scored via
    these methods must agree with the distribution's batch and event shapes.

    Args:
        value (Tensor): the tensor whose log probability is to be computed
            by the `log_prob` method.

    Raises:
        ValueError: when the rightmost dimensions of `value` do not match
            the distribution's batch and event shapes.
    """
    if not isinstance(value, torch.Tensor):
        raise ValueError("The value argument to log_prob must be a Tensor")

    # 1) Trailing dims of `value` must equal event_shape exactly.
    event_dim_start = len(value.size()) - len(self._event_shape)
    if value.size()[event_dim_start:] != self._event_shape:
        raise ValueError(
            f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
        )

    # 2) The whole shape must broadcast against batch_shape + event_shape
    #    (compare aligned from the right; a 1 on either side broadcasts).
    actual_shape = value.size()
    expected_shape = self._batch_shape + self._event_shape
    for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
        if i != 1 and j != 1 and i != j:
            raise ValueError(
                f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
            )

    # 3) Membership check against the support. A missing `support` is
    #    tolerated (warn and skip) rather than treated as an error.
    try:
        support = self.support
    except NotImplementedError:
        warnings.warn(
            f"{self.__class__} does not define `support` to enable "
            + "sample validation. Please initialize the distribution with "
            + "`validate_args=False` to turn off validation."
        )
        return
    assert support is not None
    valid = support.check(value)
    if not valid.all():
        raise ValueError(
            "Expected value argument "
            f"({type(value).__name__} of shape {tuple(value.shape)}) "
            f"to be within the support ({repr(support)}) "
            f"of the distribution {repr(self)}, "
            f"but found invalid values:\n{value}"
        )

def _get_checked_instance(self, cls, _instance=None):
    # Guard used by `.expand`: a subclass with its own __init__ must also
    # override .expand, because the base implementation cannot know how to
    # reconstruct that subclass's parameters.
    if _instance is None and type(self).__init__ != cls.__init__:
        raise NotImplementedError(
            f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
            "must also define a custom .expand() method."
        )
    # __new__ deliberately bypasses __init__ (and thus re-validation).
    return self.__new__(type(self)) if _instance is None else _instance

def __repr__(self) -> str:
    # Only parameters actually stored on the instance are shown; scalar
    # tensors print their value, larger tensors print their size.
    param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
    args_string = ", ".join(
        [
            f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
            for p in param_names
        ]
    )
    return self.__class__.__name__ + "(" + args_string + ")"
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.