Shortcuts

Source code for torch.distributions.fishersnedecor

from numbers import Number
import torch
from torch._six import nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.gamma import Gamma
from torch.distributions.utils import broadcast_all


class FisherSnedecor(Distribution):
    r"""
    Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.

    Example::

        >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
        >>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
        tensor([ 0.2453])

    Args:
        df1 (float or Tensor): degrees of freedom parameter 1
        df2 (float or Tensor): degrees of freedom parameter 2
        validate_args (bool, optional): forwarded unchanged to
            :class:`Distribution`'s constructor
    """
    arg_constraints = {'df1': constraints.positive, 'df2': constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, df1, df2, validate_args=None):
        # Bring both parameters to a common shape before anything else uses them.
        self.df1, self.df2 = broadcast_all(df1, df2)
        # Helper distributions whose ratio is drawn in ``rsample``.
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

        if isinstance(df1, Number) and isinstance(df2, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.df1.size()
        super(FisherSnedecor, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose parameters are expanded to ``batch_shape``.

        Args:
            batch_shape: the desired batch shape (anything ``torch.Size`` accepts).
            _instance: optional pre-allocated instance to fill in; passed to
                ``_get_checked_instance``.

        Returns:
            The expanded distribution, carrying the same ``_validate_args``
            flag as ``self``.
        """
        new = self._get_checked_instance(FisherSnedecor, _instance)
        batch_shape = torch.Size(batch_shape)
        new.df1 = self.df1.expand(batch_shape)
        new.df2 = self.df2.expand(batch_shape)
        new._gamma1 = self._gamma1.expand(batch_shape)
        new._gamma2 = self._gamma2.expand(batch_shape)
        # Re-initialise the base class with validation off, then restore the
        # flag of the distribution being expanded.
        super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """``df2 / (df2 - 2)``, with NaN wherever ``df2 <= 2``."""
        # Work on a copy so the masked write below never touches ``self.df2``.
        df2 = self.df2.clone(memory_format=torch.contiguous_format)
        df2[df2 <= 2] = nan
        return df2 / (df2 - 2)

    @property
    def variance(self):
        """Variance of the distribution, with NaN wherever ``df2 <= 4``.

        Computed as
        ``2 * df2**2 * (df1 + df2 - 2) / (df1 * (df2 - 2)**2 * (df2 - 4))``.
        """
        # Work on a copy so the masked write below never touches ``self.df2``.
        df2 = self.df2.clone(memory_format=torch.contiguous_format)
        df2[df2 <= 4] = nan
        return 2 * df2.pow(2) * (self.df1 + df2 - 2) / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))

    def rsample(self, sample_shape=torch.Size(())):
        """Draw a reparameterized sample as the ratio of two Gamma samples.

        Args:
            sample_shape: leading shape of the requested sample.

        Returns:
            Tensor of shape ``self._extended_shape(sample_shape)`` whose
            entries are at least ``torch.finfo(dtype).tiny``.
        """
        shape = self._extended_shape(sample_shape)
        #   X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
        #   Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
        X1 = self._gamma1.rsample(sample_shape).view(shape)
        X2 = self._gamma2.rsample(sample_shape).view(shape)
        tiny = torch.finfo(X2.dtype).tiny
        # Keep the denominator away from zero before dividing.
        X2.clamp_(min=tiny)
        Y = X1 / X2
        # Keep the result strictly positive (the declared support).
        Y.clamp_(min=tiny)
        return Y

    def log_prob(self, value):
        """Return the log-density evaluated at ``value``.

        With ``a = df1 / 2``, ``b = df2 / 2`` and ``r = df1 / df2`` the result is
        ``lgamma(a + b) - lgamma(a) - lgamma(b) + a * log(r)
        + (a - 1) * log(value) - (a + b) * log1p(r * value)``.

        Args:
            value (Tensor): points at which to evaluate the log-density;
                checked with ``_validate_sample`` when ``_validate_args`` is set.
        """
        if self._validate_args:
            self._validate_sample(value)
        ct1 = self.df1 * 0.5
        ct2 = self.df2 * 0.5
        ct3 = self.df1 / self.df2
        # t1: log of the normalising constant (negative log-Beta of ct1, ct2).
        t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
        t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
        t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
        return t1 + t2 - t3

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources