Source code for torchaudio.models._hdemucs

# *****************************************************************************
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# *****************************************************************************


import math
import typing as tp
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from torch.nn import functional as F


class _ScaledEmbedding(torch.nn.Module):
    r"""Make continuous embeddings and boost learning rate

    Args:
        num_embeddings (int): number of embeddings
        embedding_dim (int): embedding dimensions
        scale (float, optional): amount to scale learning rate (Default: 10.0)
        smooth (bool, optional): choose to apply smoothing (Default: ``False``)
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussians, the scale grows as sqrt(n), so we normalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self) -> torch.Tensor:
        return self.embedding.weight * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass for embedding with scale.
        Args:
            x (torch.Tensor): input tensor of shape `(num_embeddings)`

        Returns:
            (Tensor):
                Embedding output of shape `(num_embeddings, embedding_dim)`
        """
        out = self.embedding(x) * self.scale
        return out
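
# Editor's usage sketch (not part of the original torchaudio file; the helper name
# `_example_scaled_embedding` is illustrative only). It shows the trick documented
# above: the stored nn.Embedding weight is divided by `scale` at construction and
# multiplied back in `forward`, so optimizer steps on the raw parameter are
# effectively amplified by `scale`.
def _example_scaled_embedding() -> torch.Tensor:
    emb = _ScaledEmbedding(num_embeddings=512, embedding_dim=48, scale=10.0, smooth=True)
    frs = torch.arange(512)  # frequency-bin indices
    out = emb(frs)  # shape: (512, 48)
    # The exposed weight is the raw parameter scaled back up.
    assert torch.allclose(emb.weight, emb.embedding.weight * 10.0)
    return out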


class _HEncLayer(torch.nn.Module):

    r"""Encoder layer. This used both by the time and the frequency branch.
    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int, optional): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
        norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 0)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 4,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 0,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        pad_val = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.pad = pad_val
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad_val = [pad_val, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad_val)
        self.norm1 = norm_fn(chout)

        if self.empty:
            self.rewrite = nn.Identity()
            self.norm2 = nn.Identity()
            self.dconv = nn.Identity()
        else:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)
            self.dconv = _DConv(chout, **dconv_kw)

    def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Forward pass for encoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            inject (torch.Tensor, optional): on the layer where the time and frequency branches merge, the
                time-branch output to add to the output of the first conv; its last dimension must match the
                conv output (default: ``None``)

        Returns:
            Tensor
                output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
                    and shape `(B, C, ceil(T / stride))` for time
        """

        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            if inject.shape[-1] != y.shape[-1]:
                raise ValueError("Injection shapes do not align")
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.freq:
            B, C, Fr, T = y.shape
            y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = self.dconv(y)
        z = self.norm2(self.rewrite(y))
        z = F.glu(z, dim=1)
        return z
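
# Editor's shape sketch (not part of the original torchaudio file; the helper name
# `_example_enc_layer_shapes` is illustrative only). The channel counts below are
# assumed from the default HDemucs configuration: 2 audio channels (so 4 channels
# for the complex-as-real frequency input), 48 hidden channels, nfft=4096 giving
# 2048 retained frequency bins.
def _example_enc_layer_shapes() -> None:
    # Frequency branch: input is (B, C, F, T); F is downsampled by `stride`.
    fenc = _HEncLayer(chin=4, chout=48, kernel_size=8, stride=4, freq=True)
    z = torch.randn(1, 4, 2048, 44)
    print(fenc(z).shape)  # torch.Size([1, 48, 512, 44])

    # Time branch: input is (B, C, T); T is downsampled by `stride`.
    tenc = _HEncLayer(chin=2, chout=48, kernel_size=8, stride=4, freq=False)
    xt = torch.randn(1, 2, 44100)
    print(tenc(xt).shape)  # torch.Size([1, 48, 11025])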


class _HDecLayer(torch.nn.Module):
    r"""Decoder layer. This used both by the time and the frequency branches.
    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        last (bool, optional): whether current layer is final layer (Default: ``False``)
        kernel_size (int, optional): Kernel size for decoder (Default: 8)
        stride (int, optional): Stride for decoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 1)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
        norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 1)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        last: bool = False,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 1,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 1,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            if (kernel_size - stride) % 2 != 0:
                raise ValueError("Kernel size and stride do not align")
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            self.rewrite = nn.Identity()
            self.norm1 = nn.Identity()
        else:
            self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            self.norm1 = norm_fn(2 * chin)

    def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
        r"""Forward pass for decoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            skip (torch.Tensor, optional): skip connection from the matching encoder layer; must be ``None``
                when ``empty`` is ``True``, i.e. where the frequency and time branches separate (default: ``None``)
            length (int): Size of tensor for output

        Returns:
            (Tensor, Tensor):
                Tensor
                    output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
                        frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
                        for time domain.
                Tensor
                    contains the output just before final transposed convolution, which is used when the
                        freq. and time branch separate. Otherwise, does not matter. Shape is
                        `(B, C, F, T)` for frequency and `(B, C, T)` for time.
        """
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip
            y = F.glu(self.norm1(self.rewrite(x)), dim=1)
        else:
            y = x
            if skip is not None:
                raise ValueError("Skip must be none when empty is true.")

        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            if z.shape[-1] != length:
                raise ValueError("Last index of z must be equal to length")
        if not self.last:
            z = F.gelu(z)

        return z, y
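
# Editor's round-trip sketch (not part of the original torchaudio file; the helper
# name `_example_dec_layer_roundtrip` is illustrative only). With matching kernel
# and stride, a decoder layer undoes the downsampling of the corresponding encoder
# layer once given the original length; the zero tensor stands in for the U-Net
# skip connection.
def _example_dec_layer_roundtrip() -> None:
    enc = _HEncLayer(chin=2, chout=48, kernel_size=8, stride=4, freq=False)
    dec = _HDecLayer(chin=48, chout=2, kernel_size=8, stride=4, freq=False)
    xt = torch.randn(1, 2, 44100)
    y = enc(xt)                 # (1, 48, 11025)
    skip = torch.zeros_like(y)  # stand-in for the skip connection
    z, _ = dec(y, skip, xt.shape[-1])
    print(z.shape)              # torch.Size([1, 2, 44100])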


class HDemucs(torch.nn.Module):
    r"""Hybrid Demucs model from *Hybrid Spectrogram and Waveform Source Separation*
    :cite:`defossez2021hybrid`.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        sources (List[str]): list of source names. List can contain the following source
            options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
        audio_channels (int, optional): input/output audio channels. (Default: 2)
        channels (int, optional): initial number of hidden channels. (Default: 48)
        growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
        nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
            various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
        depth (int, optional): number of layers in encoder and decoder (Default: 6)
        freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
            the actual value controls the weight of the embedding. (Default: 0.2)
        emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
        emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
            (Default: ``True``)
        kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
        time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
        stride (int, optional): stride for encoder and decoder layers. (Default: 4)
        context (int, optional): context for 1x1 conv in the decoder. (Default: 1)
        context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
        norm_starts (int, optional): layer at which group norm starts being used.
            Decoder layers are numbered in reverse order. (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
        dconv_comp (int, optional): compression of DConv branch. (Default: 4)
        dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
        dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
        dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
    """

    def __init__(
        self,
        sources: List[str],
        audio_channels: int = 2,
        channels: int = 48,
        growth: int = 2,
        nfft: int = 4096,
        depth: int = 6,
        freq_emb: float = 0.2,
        emb_scale: int = 10,
        emb_smooth: bool = True,
        kernel_size: int = 8,
        time_stride: int = 2,
        stride: int = 4,
        context: int = 1,
        context_enc: int = 0,
        norm_starts: int = 4,
        norm_groups: int = 4,
        dconv_depth: int = 2,
        dconv_comp: int = 4,
        dconv_attn: int = 4,
        dconv_lstm: int = 4,
        dconv_init: float = 1e-4,
    ):
        super().__init__()
        self.depth = depth
        self.nfft = nfft
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.channels = channels
        self.hop_length = self.nfft // 4
        self.freq_emb = None
        self.freq_encoder = nn.ModuleList()
        self.freq_decoder = nn.ModuleList()

        self.time_encoder = nn.ModuleList()
        self.time_decoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin * 2  # number of channels for the freq branch
        chout = channels
        chout_z = channels
        freqs = self.nfft // 2

        for index in range(self.depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm_type = "group_norm" if index >= norm_starts else "none"
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                if freqs != 1:
                    raise ValueError("When freq is false, freqs must be 1.")
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm_type": norm_type,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "lstm": lstm,
                    "attn": attn,
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                },
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw)
            if freq:
                if last_freq is True and nfft == 2048:
                    kwt["stride"] = 2
                    kwt["kernel_size"] = 4
                tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt)
                self.time_encoder.append(tenc)

            self.freq_encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin * 2

            dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec)
            if freq:
                tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt)
                self.time_decoder.insert(0, tdec)
            self.freq_decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        _rescale_module(self)

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft),
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allows to easily
        # align the time and frequency branches later on.
        if hl != nfft // 4:
            raise ValueError("Hop length must be nfft // 4")
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")

        z = _spectro(x, nfft, hl)[..., :-1, :]
        if z.shape[-1] != le + 4:
            raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
        z = z[..., 2 : 2 + le]
        return z

    def _ispec(self, z, length=None):
        hl = self.hop_length
        z = F.pad(z, [0, 0, 0, 1])
        z = F.pad(z, [2, 2])
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = _ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
        """Wrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad.
        Add extra zero padding around in order for padding to not break."""
        length = x.shape[-1]
        if mode == "reflect":
            max_pad = max(padding_left, padding_right)
            if length <= max_pad:
                x = F.pad(x, (0, max_pad - length + 1))
        return F.pad(x, (padding_left, padding_right), mode, value)

    def _magnitude(self, z):
        # move the complex dimension to the channel one.
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _mask(self, m):
        # `m` is a full spectrogram and `z` is ignored.
        B, S, C, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, input: torch.Tensor):
        r"""HDemucs forward call

        Args:
            input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`

        Returns:
            Tensor
                output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
        """
        if input.ndim != 3:
            raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")

        if input.shape[1] != self.audio_channels:
            raise ValueError(
                f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
                f"Found: {input.shape[1]}."
            )

        x = input
        length = x.shape[-1]

        z = self._spec(input)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = input
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths: List[int] = []  # saved lengths to properly remove padding, freq branch.
        lengths_t: List[int] = []  # saved lengths for time branch.

        for idx, encode in enumerate(self.freq_encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.time_encoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.time_encoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.freq_decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            offset = self.depth - len(self.time_decoder)
            if idx >= offset:
                tdec = self.time_decoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    if pre.shape[2] != 1:
                        raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        if len(saved) != 0:
            raise AssertionError("saved is not empty")
        if len(lengths_t) != 0:
            raise AssertionError("lengths_t is not empty")
        if len(saved_t) != 0:
            raise AssertionError("saved_t is not empty")

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(x)
        x = self._ispec(zout, length)

        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        return x
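
# Editor's end-to-end sketch (not part of the original torchaudio file; the helper
# name `_example_hdemucs_forward` is illustrative only). It exercises the forward
# call documented above with the default 44.1 kHz configuration and random input.
def _example_hdemucs_forward() -> torch.Tensor:
    model = HDemucs(sources=["drums", "bass", "other", "vocals"])
    mix = torch.randn(1, 2, 44100)  # (batch, audio_channels, num_frames)
    with torch.no_grad():
        separated = model(mix)      # (batch, num_sources, audio_channels, num_frames)
    print(separated.shape)          # torch.Size([1, 4, 2, 44100])
    return separated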

class _DConv(torch.nn.Module):
    r"""
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.

    Args:
        channels (int): input/output channels for residual branch.
        compress (float, optional): amount of channel compression inside the branch. (default: 4)
        depth (int, optional): number of layers in the residual branch. Each layer has its own
            projection, and potentially LSTM and attention. (default: 2)
        init (float, optional): initial scale for LayerNorm. (default: 1e-4)
        norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        attn (bool, optional): use LocalAttention. (Default: ``False``)
        heads (int, optional): number of heads for the LocalAttention. (default: 4)
        ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
        lstm (bool, optional): use LSTM. (Default: ``False``)
        kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)
    """

    def __init__(
        self,
        channels: int,
        compress: float = 4,
        depth: int = 2,
        init: float = 1e-4,
        norm_type: str = "group_norm",
        attn: bool = False,
        heads: int = 4,
        ndecay: int = 4,
        lstm: bool = False,
        kernel_size: int = 3,
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size should not be divisible by 2")
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act = nn.GELU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = pow(2, d) if dilate else 1
            padding = dilation * (kernel_size // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding),
                norm_fn(hidden),
                act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels),
                nn.GLU(1),
                _LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, _BLSTM(hidden, layers=2, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        r"""DConv forward call

        Args:
            x (torch.Tensor): input tensor for convolution

        Returns:
            Tensor
                Output after being run through layers.
        """
        for layer in self.layers:
            x = x + layer(x)
        return x


class _BLSTM(torch.nn.Module):
    r"""
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be split in overlapping
    chunks and the LSTM applied separately on each chunk.

    Args:
        dim (int): dimensions at LSTM layer.
        layers (int, optional): number of LSTM layers. (default: 1)
        skip (bool, optional): (default: ``False``)
    """

    def __init__(self, dim, layers: int = 1, skip: bool = False):
        super().__init__()
        self.max_steps = 200
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""BLSTM forward call

        Args:
            x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`

        Returns:
            Tensor
                Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
        """
        B, C, T = x.shape
        y = x
        framed = False
        width = 0
        stride = 0
        nframes = 0
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = _unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


class _LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).
    Also a failed experiment with trying to provide some frequency based attention.
    """

    def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
        r"""
        Args:
            channels (int): Size of Conv1d layers.
            heads (int, optional): number of attention heads. (default: 4)
            ndecay (int, optional): number of decay controls. (default: 4)
        """
        super(_LocalState, self).__init__()
        if channels % heads != 0:
            raise ValueError("Channels must be divisible by heads.")
        self.heads = heads
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)

        self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
        if ndecay:
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            if self.query_decay.bias is None:
                raise ValueError("bias must not be None.")
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * 0, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LocalState forward call

        Args:
            x (torch.Tensor): input tensor for LocalState

        Returns:
            Tensor
                Output after being run through LocalState layer.
        """
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]
        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= math.sqrt(keys.shape[2])
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay)
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)


class _LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally residual outputs close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        r"""
        Args:
            channels (int): Size of rescaling
            init (float, optional): Scale to default to (default: 0)
        """
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LayerScale forward call

        Args:
            x (torch.Tensor): input tensor for LayerScale

        Returns:
            Tensor
                Output after rescaling tensor.
        """
        return self.scale[:, None] * x


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / K)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    shape = list(a.shape[:-1])
    length = int(a.shape[-1])
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(input=a, pad=[0, tgt_length - length])
    strides = [a.stride(dim) for dim in range(a.dim())]
    if strides[-1] != 1:
        raise ValueError("Data should be contiguous.")
    strides = strides[:-1] + [stride, 1]
    shape.append(n_frames)
    shape.append(kernel_size)
    return a.as_strided(shape, strides)


def _rescale_module(module):
    r"""
    Rescales initial weight scale for all models within the module.
    """
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            std = sub.weight.std().detach()
            scale = (std / 0.1) ** 0.5
            sub.weight.data /= scale
            if sub.bias is not None:
                sub.bias.data /= scale


def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
    other = list(x.shape[:-1])
    length = int(x.shape[-1])
    x = x.reshape(-1, length)
    z = torch.stft(
        x,
        n_fft * (1 + pad),
        hop_length,
        window=torch.hann_window(n_fft).to(x),
        win_length=n_fft,
        normalized=True,
        center=True,
        return_complex=True,
        pad_mode="reflect",
    )
    _, freqs, frame = z.shape
    other.extend([freqs, frame])
    return z.view(other)


def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
    other = list(z.shape[:-2])
    freqs = int(z.shape[-2])
    frames = int(z.shape[-1])

    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)

    x = torch.istft(
        z,
        n_fft,
        hop_length,
        window=torch.hann_window(win_length).to(z.real),
        win_length=win_length,
        normalized=True,
        length=length,
        center=True,
    )
    _, length = x.shape
    other.append(length)
    return x.view(other)

def hdemucs_low(sources: List[str]) -> HDemucs:
    """Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    return HDemucs(sources=sources, nfft=1024, depth=5)

def hdemucs_medium(sources: List[str]) -> HDemucs:
    r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
        not compatible with the original implementation in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    return HDemucs(sources=sources, nfft=2048, depth=6)

def hdemucs_high(sources: List[str]) -> HDemucs:
    r"""Builds high nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    return HDemucs(sources=sources, nfft=4096, depth=6)
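
# Editor's usage sketch (not part of the original torchaudio file; the helper name
# `_example_factory_usage` is illustrative only). In recent torchaudio releases
# these factories are exposed as torchaudio.models.hdemucs_low / hdemucs_medium /
# hdemucs_high; each variant targets a different sample-rate range.
def _example_factory_usage() -> None:
    sources = ["drums", "bass", "other", "vocals"]
    low = hdemucs_low(sources)       # ~8 kHz audio
    medium = hdemucs_medium(sources) # 16-32 kHz audio
    high = hdemucs_high(sources)     # 44.1-48 kHz audio
    print(low.nfft, medium.nfft, high.nfft)  # 1024 2048 4096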
