Shortcuts

Source code for torch.distributions.inverse_gamma

# mypy: allow-untyped-defs
import torch
from torch.distributions import constraints
from torch.distributions.gamma import Gamma
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import PowerTransform


__all__ = ["InverseGamma"]


[docs]class InverseGamma(TransformedDistribution): r""" Creates an inverse gamma distribution parameterized by :attr:`concentration` and :attr:`rate` where:: X ~ Gamma(concentration, rate) Y = 1 / X ~ InverseGamma(concentration, rate) Example:: >>> # xdoctest: +IGNORE_WANT("non-deterinistic") >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0])) >>> m.sample() tensor([ 1.2953]) Args: concentration (float or Tensor): shape parameter of the distribution (often referred to as alpha) rate (float or Tensor): rate = 1 / scale of the distribution (often referred to as beta) """ arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, } support = constraints.positive has_rsample = True def __init__(self, concentration, rate, validate_args=None): base_dist = Gamma(concentration, rate, validate_args=validate_args) neg_one = -base_dist.rate.new_ones(()) super().__init__( base_dist, PowerTransform(neg_one), validate_args=validate_args )
[docs] def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(InverseGamma, _instance) return super().expand(batch_shape, _instance=new)
@property def concentration(self): return self.base_dist.concentration @property def rate(self): return self.base_dist.rate @property def mean(self): result = self.rate / (self.concentration - 1) return torch.where(self.concentration > 1, result, torch.inf) @property def mode(self): return self.rate / (self.concentration + 1) @property def variance(self): result = self.rate.square() / ( (self.concentration - 1).square() * (self.concentration - 2) ) return torch.where(self.concentration > 2, result, torch.inf)
[docs] def entropy(self): return ( self.concentration + self.rate.log() + self.concentration.lgamma() - (1 + self.concentration) * self.concentration.digamma() )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources