Source code for torcheval.metrics.image.fid

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from importlib.util import find_spec
from typing import Any, Iterable, Optional, TypeVar, Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torcheval.metrics.metric import Metric

if find_spec("torchvision") is not None:
    from torchvision import models

    _TORCHVISION_AVAILABLE = True
else:
    _TORCHVISION_AVAILABLE = False

TFrechetInceptionDistance = TypeVar("TFrechetInceptionDistance")

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.


def _validate_torchvision_available() -> None:
    if not _TORCHVISION_AVAILABLE:
        raise RuntimeError(
            "You must have torchvision installed to use FID, please install torcheval[image]"
        )


class FIDInceptionV3(nn.Module):
    def __init__(
        self,
        weights: Optional[str] = "DEFAULT",
    ) -> None:
        """
        This class wraps the InceptionV3 model to compute FID.

        Args:
            weights (Optional[str]): Defines the pre-trained weights to use.
        """
        super().__init__()
        # pyre-ignore
        self.model = models.inception_v3(weights=weights)
        # Replace the final fc layer with an identity so the model returns pooled features
        self.model.fc = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        # Interpolating the input image tensors to be of size 299 x 299
        x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
        x = self.model(x)

        return x

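As a quick illustration (an editorial sketch, not part of the module; the variable names are illustrative, and constructing the wrapper downloads the DEFAULT InceptionV3 weights), the wrapper acts as a plain feature extractor: with fc replaced by nn.Identity, a forward pass maps a batch of images to a (batch_size, 2048) activation tensor.

extractor = FIDInceptionV3()
extractor.eval()

images = torch.rand(4, 3, 64, 64)  # float32 in [0, 1); any spatial size, forward resizes to 299 x 299
with torch.inference_mode():
    features = extractor(images)  # shape: torch.Size([4, 2048])
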

class FrechetInceptionDistance(Metric[torch.Tensor]):
    def __init__(
        self: TFrechetInceptionDistance,
        model: Optional[nn.Module] = None,
        feature_dim: int = 2048,
        device: Optional[torch.device] = None,
    ) -> None:
        """
        Computes the Frechet Inception Distance (FID) between two distributions
        of images (real and generated).

        The original paper: https://arxiv.org/pdf/1706.08500.pdf

        Args:
            model (nn.Module): Module used to compute feature activations.
                If None, a default InceptionV3 model will be used.
            feature_dim (int): The number of features in the model's output;
                the default is 2048, matching the default InceptionV3 model.
            device (torch.device): The device where the computations will be performed.
                If None, the default device will be used.
        """
        _validate_torchvision_available()

        super().__init__(device=device)

        self._FID_parameter_check(model=model, feature_dim=feature_dim)

        if model is None:
            model = FIDInceptionV3()

        # Set the model and put it in evaluation mode
        self.model = model.to(device)
        self.model.eval()
        self.model.requires_grad_(False)

        # Initialize state variables used to compute FID
        self._add_state("real_sum", torch.zeros(feature_dim, device=device))
        self._add_state(
            "real_cov_sum", torch.zeros((feature_dim, feature_dim), device=device)
        )
        self._add_state("fake_sum", torch.zeros(feature_dim, device=device))
        self._add_state(
            "fake_cov_sum", torch.zeros((feature_dim, feature_dim), device=device)
        )
        self._add_state("num_real_images", torch.tensor(0, device=device).int())
        self._add_state("num_fake_images", torch.tensor(0, device=device).int())
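
    # Implementation note (editorial): the covariance is accumulated in streaming
    # form. Each update() adds activations.T @ activations to *_cov_sum, and
    # compute() recovers the sample covariance as
    # (cov_sum - n * mean^T mean) / (n - 1), so image batches never need to be stored.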

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TFrechetInceptionDistance, images: Tensor, is_real: bool
    ) -> TFrechetInceptionDistance:
        """
        Update the states with a batch of real or fake images.

        Args:
            images (Tensor): A batch of images.
            is_real (bool): Denotes whether the images are real or not.
        """
        self._FID_update_input_check(images=images, is_real=is_real)

        images = images.to(self.device)

        # Compute activations for images using the given model
        activations = self.model(images)

        batch_size = images.shape[0]

        # Update the state variables used to compute FID
        if is_real:
            self.num_real_images += batch_size
            self.real_sum += torch.sum(activations, dim=0)
            self.real_cov_sum += torch.matmul(activations.T, activations)
        else:
            self.num_fake_images += batch_size
            self.fake_sum += torch.sum(activations, dim=0)
            self.fake_cov_sum += torch.matmul(activations.T, activations)

        return self

    @torch.inference_mode()
    def merge_state(
        self: TFrechetInceptionDistance, metrics: Iterable[TFrechetInceptionDistance]
    ) -> TFrechetInceptionDistance:
        """
        Merge the state of other FID instances into this instance.

        Args:
            metrics (Iterable[FID]): The other FID instance(s) whose state will
                be merged into this instance.
        """
        for metric in metrics:
            self.real_sum += metric.real_sum.to(self.device)
            self.real_cov_sum += metric.real_cov_sum.to(self.device)
            self.fake_sum += metric.fake_sum.to(self.device)
            self.fake_cov_sum += metric.fake_cov_sum.to(self.device)
            self.num_real_images += metric.num_real_images.to(self.device)
            self.num_fake_images += metric.num_fake_images.to(self.device)

        return self

    @torch.inference_mode()
    def compute(self: TFrechetInceptionDistance) -> Tensor:
        """
        Compute the FID.

        Returns:
            tensor: The FID.
        """
        # If fewer than two images from each distribution have been
        # observed, warn and return 0.0.
        if (self.num_real_images < 2) or (self.num_fake_images < 2):
            warnings.warn(
                "Computing FID requires at least 2 real images and 2 fake images, "
                f"but currently running with {self.num_real_images} real images and {self.num_fake_images} fake images. "
                "Returning 0.0.",
                RuntimeWarning,
                stacklevel=2,
            )
            return torch.tensor(0.0)

        # Compute the mean activations for each distribution
        real_mean = (self.real_sum / self.num_real_images).unsqueeze(0)
        fake_mean = (self.fake_sum / self.num_fake_images).unsqueeze(0)

        # Compute the covariance matrices for each distribution
        real_cov_num = self.real_cov_sum - self.num_real_images * torch.matmul(
            real_mean.T, real_mean
        )
        real_cov = real_cov_num / (self.num_real_images - 1)
        fake_cov_num = self.fake_cov_sum - self.num_fake_images * torch.matmul(
            fake_mean.T, fake_mean
        )
        fake_cov = fake_cov_num / (self.num_fake_images - 1)

        # Compute the Frechet Distance between the distributions
        fid = self._calculate_frechet_distance(
            real_mean.squeeze(), real_cov, fake_mean.squeeze(), fake_cov
        )
        return fid

    def _calculate_frechet_distance(
        self: TFrechetInceptionDistance,
        mu1: Tensor,
        sigma1: Tensor,
        mu2: Tensor,
        sigma2: Tensor,
    ) -> Tensor:
        """
        Calculate the Frechet Distance between two multivariate Gaussian distributions.

        Args:
            mu1 (Tensor): The mean of the first distribution.
            sigma1 (Tensor): The covariance matrix of the first distribution.
            mu2 (Tensor): The mean of the second distribution.
            sigma2 (Tensor): The covariance matrix of the second distribution.

        Returns:
            tensor: The Frechet Distance between the two distributions.
""" # Compute the squared distance between the means mean_diff = mu1 - mu2 mean_diff_squared = mean_diff.square().sum(dim=-1) # Calculate the sum of the traces of both covariance matrices trace_sum = sigma1.trace() + sigma2.trace() # Compute the eigenvalues of the matrix product of the real and fake covariance matrices sigma_mm = torch.matmul(sigma1, sigma2) eigenvals = torch.linalg.eigvals(sigma_mm) # Take the square root of each eigenvalue and take its sum sqrt_eigenvals_sum = eigenvals.sqrt().real.sum(dim=-1) # Calculate the FID using the squared distance between the means, # the sum of the traces of the covariance matrices, and the sum of the square roots of the eigenvalues fid = mean_diff_squared + trace_sum - 2 * sqrt_eigenvals_sum return fid def _FID_parameter_check( self: TFrechetInceptionDistance, model: Optional[nn.Module], feature_dim: int, ) -> None: # Whatever the model, the feature_dim needs to be set if feature_dim is None or feature_dim <= 0: raise RuntimeError("feature_dim has to be a positive integer") if model is None and feature_dim != 2048: raise RuntimeError( "When the default Inception v3 model is used, feature_dim needs to be set to 2048" ) def _FID_update_input_check( self: TFrechetInceptionDistance, images: torch.Tensor, is_real: bool ) -> None: if not torch.is_tensor(images): raise ValueError(f"Expected tensor as input, but got {type(images)}.") if images.dim() != 4: raise ValueError( f"Expected 4D tensor as input. But input has {images.dim()} dimenstions." ) if images.size()[1] != 3: raise ValueError(f"Expected 3 channels as input. Got {images.size()[1]}.") if type(is_real) != bool: raise ValueError( f"Expected 'real' to be of type bool but got {type(is_real)}.", ) if isinstance(self.model, FIDInceptionV3): if images.dtype != torch.float32: raise ValueError( f"When default inception-v3 model is used, images expected to be `torch.float32`, but got {images.dtype}." ) if images.min() < 0 or images.max() > 1: raise ValueError( "When default inception-v3 model is used, images are expected to be in the [0, 1] interval" ) def to( self: TFrechetInceptionDistance, device: Union[str, torch.device], *args: Any, **kwargs: Any, ) -> TFrechetInceptionDistance: super().to(device=device) self.model.to(self.device) return self

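Finally, a short end-to-end usage sketch (editorial, not from the source): random tensors stand in for real datasets, and FrechetInceptionDistance is assumed to be importable from torcheval.metrics, its public location.

import torch
from torcheval.metrics import FrechetInceptionDistance

fid = FrechetInceptionDistance()  # default InceptionV3 backbone, feature_dim=2048

# The default model expects float32, 3-channel NCHW images scaled to [0, 1].
real_images = torch.rand(8, 3, 299, 299)
fake_images = torch.rand(8, 3, 299, 299)

fid.update(real_images, is_real=True)
fid.update(fake_images, is_real=False)
score = fid.compute()  # 0-dim tensor; lower means the two distributions are closer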