
HuBERTPretrainModel

class torchaudio.models.HuBERTPretrainModel

HuBERT model used for pretraining in HuBERT [Hsu et al., 2021].

Note

To build the model, please use one of the factory functions, hubert_pretrain_base(), hubert_pretrain_large() or hubert_pretrain_xlarge().

Parameters:
  • wav2vec2 (Wav2Vec2Model) – Wav2Vec2 encoder that generates the transformer outputs.

  • mask_generator (torch.nn.Module) – Mask generator that generates the mask for masked prediction during training.

  • logit_generator (torch.nn.Module) – Logit generator that predicts the logits of the masked and unmasked inputs.

  • feature_grad_mult (float or None) – The factor by which to scale the gradients of the convolutional feature extraction layers. If None, the gradients of the feature extraction layers are not affected. The scale factor does not affect the forward pass.
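The behavior described for feature_grad_mult (identity in the forward pass, gradients multiplied by the factor in the backward pass) can be sketched with a custom torch.autograd.Function. GradScale below is a hypothetical helper written for illustration; it is not the torchaudio internal that implements this option.

```python
import torch


class GradScale(torch.autograd.Function):
    """Hypothetical helper: identity forward, scaled gradient backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        # Return a copy with identical values, so the forward pass is unaffected.
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Scale the incoming gradient; the float `scale` receives no gradient.
        return grad_output * ctx.scale, None


# Stand-in for convolutional features: (batch, frames, channels).
features = torch.randn(2, 49, 512, requires_grad=True)
scaled = GradScale.apply(features, 0.1)
scaled.sum().backward()

print(torch.equal(scaled, features))  # True: forward values unchanged
print(features.grad.unique())         # tensor([0.1000])
```

A factor below 1 slows down updates of the feature extractor relative to the transformer while leaving the features themselves untouched.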

forward

HuBERTPretrainModel.forward(waveforms: Tensor, labels: Tensor, audio_lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]

Compute the sequences of probability distributions over labels.

Parameters:
  • waveforms (Tensor) – Audio tensor of dimension [batch, frames].

  • labels (Tensor) – Labels for pre-training. A Tensor of dimension [batch, frames].

  • audio_lengths (Tensor or None, optional) – Indicates the valid length of each audio clip in the batch. Shape: [batch, ]. When waveforms contains audio clips of different durations, providing audio_lengths lets the model compute the corresponding valid output lengths and apply the proper mask in the transformer attention layers. If None, all audio in waveforms is assumed to be valid over its full length. Default: None.
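How audio_lengths turns into valid output lengths can be sketched by pushing each length through the (kernel, stride) pairs of the convolutional feature extractor, using the unpadded 1-D convolution length formula. The layer configuration below, [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2, is an assumption (the usual wav2vec 2.0 / HuBERT layout), and valid_output_lengths is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

# Assumed (kernel_size, stride) of each convolution in the feature extractor.
CONV_LAYERS: List[Tuple[int, int]] = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2


def valid_output_lengths(audio_lengths: Sequence[int]) -> List[int]:
    """Hypothetical helper: number of valid feature frames per waveform."""
    out = []
    for length in audio_lengths:
        for kernel, stride in CONV_LAYERS:
            # Output length of an unpadded convolution, clamped at zero.
            length = max(0, (length - kernel) // stride + 1)
        out.append(length)
    return out


# Two 16 kHz clips in one padded batch: 1.0 s and 0.75 s of real audio.
print(valid_output_lengths([16000, 12000]))  # [49, 37]
```

Frames at or beyond each clip's valid length are the ones treated as padding and excluded from attention.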

Returns:
  • Tensor – The masked sequences of probability distributions (as logits). Shape: (masked_frames, num_labels).

  • Tensor – The unmasked sequences of probability distributions (as logits). Shape: (unmasked_frames, num_labels).

  • Tensor – The feature mean value used for the additional penalty loss. Shape: (1,).

Return type:

(Tensor, Tensor, Tensor)
