torcheval.metrics.FrechetAudioDistance
class torcheval.metrics.FrechetAudioDistance(preproc: Callable[[Tensor], Tensor], model: Module, embedding_dim: int, device: Optional[device] = None)

Computes the Fréchet distance between predicted and target audio waveforms.

Original paper: https://arxiv.org/abs/1812.08466
Parameters:
- preproc (Callable[[torch.Tensor], torch.Tensor]) – Callable for preprocessing waveforms before they are passed to model.
- model (torch.nn.Module) – Model for generating embeddings from preprocessed waveforms.
- embedding_dim (int) – Size of embedding.
- device (torch.device or None, optional) – Device where computations will be performed. If None, the default device will be used. (Default: None)
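The constructor wires a preprocessing callable and an embedding model together. The sketch below shows that pattern with a toy embedder; DummyEmbedder, crop_preproc, and the fixed 1024-sample input length are illustrative assumptions rather than part of torcheval, and the exact waveform and embedding shapes the metric expects are also assumptions here.

```python
import torch
from torcheval.metrics import FrechetAudioDistance

# Hypothetical embedding network: any nn.Module mapping preprocessed
# waveforms to fixed-size embeddings could stand in for it.
class DummyEmbedder(torch.nn.Module):
    def __init__(self, embedding_dim: int = 128) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(1024, embedding_dim)

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        # (batch, 1024) -> (batch, embedding_dim)
        return self.proj(waveforms)

# Hypothetical preprocessing: crop each waveform to the first 1024 samples
# so it matches the embedder's expected input length.
def crop_preproc(waveforms: torch.Tensor) -> torch.Tensor:
    return waveforms[..., :1024]

fad = FrechetAudioDistance(
    preproc=crop_preproc,
    model=DummyEmbedder(embedding_dim=128),
    embedding_dim=128,  # must match the embedder's output size
)
```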
__init__(preproc: Callable[[Tensor], Tensor], model: Module, embedding_dim: int, device: Optional[device] = None) → None

Initialize a metric object and its internal states.

Use self._add_state() to initialize state variables of your metric class. The state variables should be either a torch.Tensor, a list of torch.Tensor, or a dictionary with torch.Tensor as values.
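For context, here is a minimal sketch of how a custom metric might follow that guidance. It assumes the base class is importable as torcheval.metrics.Metric and that _add_state(name, default) exposes the registered state as an attribute; both are assumptions about torcheval internals, and SumMetric itself is purely illustrative.

```python
import torch
from torcheval.metrics import Metric

# Minimal sketch of a custom metric keeping a single tensor-valued state.
# Assumes _add_state(name, default) makes the state available as self.<name>.
class SumMetric(Metric[torch.Tensor]):
    def __init__(self, *, device=None):
        super().__init__(device=device)
        self._add_state("total", torch.tensor(0.0, device=self.device))

    def update(self, x: torch.Tensor):
        # Accumulate the running sum on the metric's device.
        self.total += x.sum().to(self.device)
        return self

    def compute(self) -> torch.Tensor:
        return self.total

    def merge_state(self, metrics):
        # Fold in the states of other SumMetric instances.
        for other in metrics:
            self.total += other.total.to(self.device)
        return self
```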
Methods

__init__(preproc, model, embedding_dim[, device]) – Initialize a metric object and its internal states.
compute() – Computes the Fréchet distance on the current set of internal states.
load_state_dict(state_dict[, strict]) – Loads metric state variables from state_dict.
merge_state(fads) – Merges the states of other FrechetAudioDistance instances into those of the current instance.
reset() – Reset the metric state variables to their default value.
state_dict() – Save metric state variables in state_dict.
to(device, *args, **kwargs) – Move tensors in metric state variables to device.
update(preds, targets) – Update states with a batch of predicted and target waveforms.
with_vggish([device]) – Builds an instance of FrechetAudioDistance with TorchAudio's pretrained VGGish model.

Attributes
device – The last input device of Metric.to().
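A minimal end-to-end sketch using the pretrained VGGish embedder follows. The 16 kHz sample rate, the (batch, num_samples) waveform shape, and the random input data are assumptions for illustration; a score computed on random audio is not meaningful and only demonstrates the update/compute flow.

```python
import torch
from torcheval.metrics import FrechetAudioDistance

# Build the metric around TorchAudio's pretrained VGGish model
# (requires torchaudio and may download weights on first use).
fad = FrechetAudioDistance.with_vggish()

# Toy waveforms: 4 one-second clips at an assumed 16 kHz sample rate.
# In practice, preds would be generated audio and targets reference audio.
preds = torch.rand(4, 16000)
targets = torch.rand(4, 16000)

fad.update(preds, targets)   # accumulate embedding statistics
score = fad.compute()        # Fréchet Audio Distance over all updates so far
print(score)

fad.reset()                  # clear internal states before a new evaluation
```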