Source code for torchaudio.models._hdemucs
# *****************************************************************************
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# *****************************************************************************
import math
import typing as tp
from typing import Any, Dict, List, Optional
import torch
from torch import nn
from torch.nn import functional as F
class _ScaledEmbedding(torch.nn.Module):
r"""Make continuous embeddings and boost learning rate
Args:
num_embeddings (int): number of embeddings
embedding_dim (int): embedding dimensions
scale (float, optional): amount to scale learning rate (Default: 10.0)
smooth (bool, optional): choose to apply smoothing (Default: ``False``)
"""
def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
super().__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
if smooth:
weight = torch.cumsum(self.embedding.weight.data, dim=0)
# When summing gaussians, the scale grows as sqrt(n), so we normalize by that.
weight = weight / torch.arange(1, num_embeddings + 1).sqrt()[:, None]
self.embedding.weight.data[:] = weight
self.embedding.weight.data /= scale
self.scale = scale
@property
def weight(self) -> torch.Tensor:
return self.embedding.weight * self.scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""Forward pass for embedding with scale.
Args:
x (torch.Tensor): input tensor of shape `(num_embeddings)`
Returns:
(Tensor):
Embedding output of shape `(num_embeddings, embedding_dim)`
"""
out = self.embedding(x) * self.scale
return out
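def _example_scaled_embedding():
    # Illustrative sketch, not part of the original torchaudio module: the stored
    # weights are divided by `scale` at init and multiplied back in `forward` and in
    # the `weight` property, so gradients on the stored parameter are amplified,
    # which the docstring describes as boosting the learning rate. All values below
    # are assumptions chosen for demonstration.
    emb = _ScaledEmbedding(num_embeddings=512, embedding_dim=48, scale=10.0)
    idx = torch.arange(512)
    out = emb(idx)  # (512, 48); same statistics as a plain nn.Embedding at init
    return out, emb.weight  # the `weight` property undoes the division by `scale`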
class _HEncLayer(torch.nn.Module):
r"""Encoder layer. This used both by the time and the frequency branch.
Args:
chin (int): number of input channels.
chout (int): number of output channels.
kernel_size (int, optional): Kernel size for encoder (Default: 8)
stride (int, optional): Stride for encoder layer (Default: 4)
norm_groups (int, optional): number of groups for group norm. (Default: 4)
empty (bool, optional): used to make a layer with just the first conv. this is used
before merging the time and freq. branches. (Default: ``False``)
freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
context (int, optional): context size for the 1x1 conv. (Default: 0)
dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
pad (bool, optional): true to pad the input. Padding is done so that the output size is
always the input size / stride. (Default: ``True``)
"""
def __init__(
self,
chin: int,
chout: int,
kernel_size: int = 8,
stride: int = 4,
norm_groups: int = 4,
empty: bool = False,
freq: bool = True,
norm_type: str = "group_norm",
context: int = 0,
dconv_kw: Optional[Dict[str, Any]] = None,
pad: bool = True,
):
super().__init__()
if dconv_kw is None:
dconv_kw = {}
norm_fn = lambda d: nn.Identity() # noqa
if norm_type == "group_norm":
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
pad_val = kernel_size // 4 if pad else 0
klass = nn.Conv1d
self.freq = freq
self.kernel_size = kernel_size
self.stride = stride
self.empty = empty
self.pad = pad_val
if freq:
kernel_size = [kernel_size, 1]
stride = [stride, 1]
pad_val = [pad_val, 0]
klass = nn.Conv2d
self.conv = klass(chin, chout, kernel_size, stride, pad_val)
self.norm1 = norm_fn(chout)
if self.empty:
self.rewrite = nn.Identity()
self.norm2 = nn.Identity()
self.dconv = nn.Identity()
else:
self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
self.norm2 = norm_fn(2 * chout)
self.dconv = _DConv(chout, **dconv_kw)
def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""Forward pass for encoding layer.
Size depends on whether frequency or time
Args:
x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
`(B, C, T)` for time
inject (torch.Tensor or None, optional): on the last layer, the time branch output is merged into the
frequency branch by adding this tensor to the output of the first convolution (Default: ``None``)
Returns:
Tensor
output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
and shape `(B, C, ceil(T / stride))` for time
"""
if not self.freq and x.dim() == 4:
B, C, Fr, T = x.shape
x = x.view(B, -1, T)
if not self.freq:
le = x.shape[-1]
if not le % self.stride == 0:
x = F.pad(x, (0, self.stride - (le % self.stride)))
y = self.conv(x)
if self.empty:
return y
if inject is not None:
if inject.shape[-1] != y.shape[-1]:
raise ValueError("Injection shapes do not align")
if inject.dim() == 3 and y.dim() == 4:
inject = inject[:, :, None]
y = y + inject
y = F.gelu(self.norm1(y))
if self.freq:
B, C, Fr, T = y.shape
y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
y = self.dconv(y)
y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
else:
y = self.dconv(y)
z = self.norm2(self.rewrite(y))
z = F.glu(z, dim=1)
return z
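def _example_henc_layer_shapes():
    # Illustrative sketch, not part of the original torchaudio module: the encoder
    # layer strides over frequency for the spectrogram branch (4D input) and over
    # time for the waveform branch (3D input). Values are assumptions for demonstration.
    enc_f = _HEncLayer(chin=4, chout=48, freq=True)
    y_f = enc_f(torch.randn(1, 4, 2048, 100))  # (1, 48, 512, 100): F / stride
    enc_t = _HEncLayer(chin=2, chout=48, freq=False)
    y_t = enc_t(torch.randn(1, 2, 40000))  # (1, 48, 10000): ceil(T / stride)
    return y_f.shape, y_t.shape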
class _HDecLayer(torch.nn.Module):
r"""Decoder layer. This used both by the time and the frequency branches.
Args:
chin (int): number of input channels.
chout (int): number of output channels.
last (bool, optional): whether current layer is final layer (Default: ``False``)
kernel_size (int, optional): Kernel size for decoder (Default: 8)
stride (int, optional): Stride for decoder layer (Default: 4)
norm_groups (int, optional): number of groups for group norm. (Default: 1)
empty (bool, optional): used to make a layer with just the first conv. this is used
before merging the time and freq. branches. (Default: ``False``)
freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
context (int, optional): context size for the 1x1 conv. (Default: 1)
dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
pad (bool, optional): true to pad the input. Padding is done so that the output size is
always the input size / stride. (Default: ``True``)
"""
def __init__(
self,
chin: int,
chout: int,
last: bool = False,
kernel_size: int = 8,
stride: int = 4,
norm_groups: int = 1,
empty: bool = False,
freq: bool = True,
norm_type: str = "group_norm",
context: int = 1,
dconv_kw: Optional[Dict[str, Any]] = None,
pad: bool = True,
):
super().__init__()
if dconv_kw is None:
dconv_kw = {}
norm_fn = lambda d: nn.Identity() # noqa
if norm_type == "group_norm":
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
if pad:
if (kernel_size - stride) % 2 != 0:
raise ValueError("Kernel size and stride do not align")
pad = (kernel_size - stride) // 2
else:
pad = 0
self.pad = pad
self.last = last
self.freq = freq
self.chin = chin
self.empty = empty
self.stride = stride
self.kernel_size = kernel_size
klass = nn.Conv1d
klass_tr = nn.ConvTranspose1d
if freq:
kernel_size = [kernel_size, 1]
stride = [stride, 1]
klass = nn.Conv2d
klass_tr = nn.ConvTranspose2d
self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
self.norm2 = norm_fn(chout)
if self.empty:
self.rewrite = nn.Identity()
self.norm1 = nn.Identity()
else:
self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
self.norm1 = norm_fn(2 * chin)
def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
r"""Forward pass for decoding layer.
Size depends on whether frequency or time
Args:
x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
`(B, C, T)` for time
skip (torch.Tensor or None, optional): skip connection from the matching encoder layer, added to the
input; must be ``None`` when ``empty`` is ``True`` (Default: ``None``)
length (int): Size of tensor for output
Returns:
(Tensor, Tensor):
Tensor
output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
for time domain.
Tensor
contains the output just before final transposed convolution, which is used when the
freq. and time branch separate. Otherwise, does not matter. Shape is
`(B, C, F, T)` for frequency and `(B, C, T)` for time.
"""
if self.freq and x.dim() == 3:
B, C, T = x.shape
x = x.view(B, self.chin, -1, T)
if not self.empty:
x = x + skip
y = F.glu(self.norm1(self.rewrite(x)), dim=1)
else:
y = x
if skip is not None:
raise ValueError("Skip must be none when empty is true.")
z = self.norm2(self.conv_tr(y))
if self.freq:
if self.pad:
z = z[..., self.pad : -self.pad, :]
else:
z = z[..., self.pad : self.pad + length]
if z.shape[-1] != length:
raise ValueError("Last index of z must be equal to length")
if not self.last:
z = F.gelu(z)
return z, y
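def _example_hdec_layer_shapes():
    # Illustrative sketch, not part of the original torchaudio module: the decoder
    # layer upsamples by `stride` with a transposed convolution, then trims the
    # result to the requested `length`. Values are assumptions for demonstration.
    dec_t = _HDecLayer(chin=48, chout=2, freq=False, last=True)
    x = torch.randn(1, 48, 10000)
    skip = torch.randn(1, 48, 10000)
    z, pre = dec_t(x, skip, length=40000)  # z: (1, 2, 40000), pre: (1, 48, 10000)
    return z.shape, pre.shape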
class HDemucs(torch.nn.Module):
r"""Hybrid Demucs model from
*Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.
See Also:
* :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.
Args:
sources (List[str]): list of source names. List can contain the following source
options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
audio_channels (int, optional): input/output audio channels. (Default: 2)
channels (int, optional): initial number of hidden channels. (Default: 48)
growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
depth (int, optional): number of layers in encoder and decoder (Default: 6)
freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
the actual value controls the weight of the embedding. (Default: 0.2)
emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
(Default: ``True``)
kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
stride (int, optional): stride for encoder and decoder layers. (Default: 4)
context (int, optional): context for 1x1 conv in the decoder. (Default: 1)
context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
norm_starts (int, optional): layer at which group norm starts being used.
decoder layers are numbered in reverse order. (Default: 4)
norm_groups (int, optional): number of groups for group norm. (Default: 4)
dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
dconv_comp (int, optional): compression of DConv branch. (Default: 4)
dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
"""
def __init__(
self,
sources: List[str],
audio_channels: int = 2,
channels: int = 48,
growth: int = 2,
nfft: int = 4096,
depth: int = 6,
freq_emb: float = 0.2,
emb_scale: int = 10,
emb_smooth: bool = True,
kernel_size: int = 8,
time_stride: int = 2,
stride: int = 4,
context: int = 1,
context_enc: int = 0,
norm_starts: int = 4,
norm_groups: int = 4,
dconv_depth: int = 2,
dconv_comp: int = 4,
dconv_attn: int = 4,
dconv_lstm: int = 4,
dconv_init: float = 1e-4,
):
super().__init__()
self.depth = depth
self.nfft = nfft
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.channels = channels
self.hop_length = self.nfft // 4
self.freq_emb = None
self.freq_encoder = nn.ModuleList()
self.freq_decoder = nn.ModuleList()
self.time_encoder = nn.ModuleList()
self.time_decoder = nn.ModuleList()
chin = audio_channels
chin_z = chin * 2 # number of channels for the freq branch
chout = channels
chout_z = channels
freqs = self.nfft // 2
for index in range(self.depth):
lstm = index >= dconv_lstm
attn = index >= dconv_attn
norm_type = "group_norm" if index >= norm_starts else "none"
freq = freqs > 1
stri = stride
ker = kernel_size
if not freq:
if freqs != 1:
raise ValueError("When freq is false, freqs must be 1.")
ker = time_stride * 2
stri = time_stride
pad = True
last_freq = False
if freq and freqs <= kernel_size:
ker = freqs
pad = False
last_freq = True
kw = {
"kernel_size": ker,
"stride": stri,
"freq": freq,
"pad": pad,
"norm_type": norm_type,
"norm_groups": norm_groups,
"dconv_kw": {
"lstm": lstm,
"attn": attn,
"depth": dconv_depth,
"compress": dconv_comp,
"init": dconv_init,
},
}
kwt = dict(kw)
kwt["freq"] = 0
kwt["kernel_size"] = kernel_size
kwt["stride"] = stride
kwt["pad"] = True
kw_dec = dict(kw)
if last_freq:
chout_z = max(chout, chout_z)
chout = chout_z
enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw)
if freq:
if last_freq is True and nfft == 2048:
kwt["stride"] = 2
kwt["kernel_size"] = 4
tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt)
self.time_encoder.append(tenc)
self.freq_encoder.append(enc)
if index == 0:
chin = self.audio_channels * len(self.sources)
chin_z = chin * 2
dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec)
if freq:
tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt)
self.time_decoder.insert(0, tdec)
self.freq_decoder.insert(0, dec)
chin = chout
chin_z = chout_z
chout = int(growth * chout)
chout_z = int(growth * chout_z)
if freq:
if freqs <= kernel_size:
freqs = 1
else:
freqs //= stride
if index == 0 and freq_emb:
self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
self.freq_emb_scale = freq_emb
_rescale_module(self)
def _spec(self, x):
hl = self.hop_length
nfft = self.nfft
x0 = x # noqa
# We re-pad the signal in order to keep the property
# that the size of the output is exactly the size of the input
# divided by the stride (here hop_length), when divisible.
# This is achieved by padding by 1/4th of the kernel size (here nfft),
# which is not directly supported by torch.stft.
# Having all convolution operations follow this convention allows us to easily
# align the time and frequency branches later on.
if hl != nfft // 4:
raise ValueError("Hop length must be nfft // 4")
le = int(math.ceil(x.shape[-1] / hl))
pad = hl // 2 * 3
x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")
z = _spectro(x, nfft, hl)[..., :-1, :]
if z.shape[-1] != le + 4:
raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
z = z[..., 2 : 2 + le]
return z
def _ispec(self, z, length=None):
hl = self.hop_length
z = F.pad(z, [0, 0, 0, 1])
z = F.pad(z, [2, 2])
pad = hl // 2 * 3
le = hl * int(math.ceil(length / hl)) + 2 * pad
x = _ispectro(z, hl, length=le)
x = x[..., pad : pad + length]
return x
def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
"""Wrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad.
Add extra zero padding around in order for padding to not break."""
length = x.shape[-1]
if mode == "reflect":
max_pad = max(padding_left, padding_right)
if length <= max_pad:
x = F.pad(x, (0, max_pad - length + 1))
return F.pad(x, (padding_left, padding_right), mode, value)
def _magnitude(self, z):
# move the complex dimension to the channel one.
B, C, Fr, T = z.shape
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
m = m.reshape(B, C * 2, Fr, T)
return m
def _mask(self, m):
# `m` is a full spectrogram with real/imaginary parts stacked along channels; convert back to complex.
B, S, C, Fr, T = m.shape
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
out = torch.view_as_complex(out.contiguous())
return out
def forward(self, input: torch.Tensor):
r"""HDemucs forward call
Args:
input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`
Returns:
Tensor
output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
"""
if input.ndim != 3:
raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")
if input.shape[1] != self.audio_channels:
raise ValueError(
f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
f"Found:{input.shape[1]}."
)
x = input
length = x.shape[-1]
z = self._spec(input)
mag = self._magnitude(z)
x = mag
B, C, Fq, T = x.shape
# unlike previous Demucs, we always normalize because it is easier.
mean = x.mean(dim=(1, 2, 3), keepdim=True)
std = x.std(dim=(1, 2, 3), keepdim=True)
x = (x - mean) / (1e-5 + std)
# x will be the freq. branch input.
# Prepare the time branch input.
xt = input
meant = xt.mean(dim=(1, 2), keepdim=True)
stdt = xt.std(dim=(1, 2), keepdim=True)
xt = (xt - meant) / (1e-5 + stdt)
saved = [] # skip connections, freq.
saved_t = [] # skip connections, time.
lengths: List[int] = [] # saved lengths to properly remove padding, freq branch.
lengths_t: List[int] = [] # saved lengths for time branch.
for idx, encode in enumerate(self.freq_encoder):
lengths.append(x.shape[-1])
inject = None
if idx < len(self.time_encoder):
# we have not yet merged branches.
lengths_t.append(xt.shape[-1])
tenc = self.time_encoder[idx]
xt = tenc(xt)
if not tenc.empty:
# save for skip connection
saved_t.append(xt)
else:
# tenc contains just the first conv., so that now time and freq.
# branches have the same shape and can be merged.
inject = xt
x = encode(x, inject)
if idx == 0 and self.freq_emb is not None:
# add frequency embedding to allow for non-equivariant convolutions
# over the frequency axis.
frs = torch.arange(x.shape[-2], device=x.device)
emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
x = x + self.freq_emb_scale * emb
saved.append(x)
x = torch.zeros_like(x)
xt = torch.zeros_like(x)
# initialize everything to zero (signal will go through u-net skips).
for idx, decode in enumerate(self.freq_decoder):
skip = saved.pop(-1)
x, pre = decode(x, skip, lengths.pop(-1))
# `pre` contains the output just before final transposed convolution,
# which is used when the freq. and time branch separate.
offset = self.depth - len(self.time_decoder)
if idx >= offset:
tdec = self.time_decoder[idx - offset]
length_t = lengths_t.pop(-1)
if tdec.empty:
if pre.shape[2] != 1:
raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
pre = pre[:, :, 0]
xt, _ = tdec(pre, None, length_t)
else:
skip = saved_t.pop(-1)
xt, _ = tdec(xt, skip, length_t)
if len(saved) != 0:
raise AssertionError("saved is not empty")
if len(lengths_t) != 0:
raise AssertionError("lengths_t is not empty")
if len(saved_t) != 0:
raise AssertionError("saved_t is not empty")
S = len(self.sources)
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]
zout = self._mask(x)
x = self._ispec(zout, length)
xt = xt.view(B, S, -1, length)
xt = xt * stdt[:, None] + meant[:, None]
x = xt + x
return x
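def _example_hdemucs_forward():
    # Illustrative sketch, not part of the original torchaudio module: end-to-end
    # separation with the default configuration. The output stacks one estimate per
    # source along dim 1. All values below are assumptions chosen for demonstration.
    model = HDemucs(sources=["drums", "bass", "other", "vocals"])
    mixture = torch.randn(1, 2, 44100 * 5)  # (batch, audio_channels, num_frames)
    with torch.no_grad():
        separated = model(mixture)  # (1, 4, 2, 44100 * 5): one estimate per source
    return separated.shape


def _example_spec_alignment():
    # Illustrative sketch, not part of the original torchaudio module: `_spec` pads so
    # that the number of STFT frames equals ceil(num_frames / hop_length), keeping the
    # frequency branch aligned with the time branch, and `_ispec` maps back to the
    # requested length. Values are assumptions chosen for demonstration.
    model = HDemucs(sources=["vocals", "other"])
    wav = torch.randn(1, 2, 44100)
    spec = model._spec(wav)  # (1, 2, 2048, 44): nfft // 2 bins, ceil(44100 / 1024) frames
    return model._ispec(spec, length=44100).shape  # (1, 2, 44100)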
class _DConv(torch.nn.Module):
r"""
New residual branches in each encoder layer.
This alternates dilated convolutions, potentially with LSTMs and attention.
Also, before entering each residual branch, the dimension is projected onto a smaller subspace,
e.g. of dim `channels // compress`.
Args:
channels (int): input/output channels for residual branch.
compress (float, optional): amount of channel compression inside the branch. (default: 4)
depth (int, optional): number of layers in the residual branch. Each layer has its own
projection, and potentially LSTM and attention. (default: 2)
init (float, optional): initial scale for the LayerScale rescaling. (default: 1e-4)
norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
attn (bool, optional): use LocalAttention. (Default: ``False``)
heads (int, optional): number of heads for the LocalAttention. (default: 4)
ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
lstm (bool, optional): use LSTM. (Default: ``False``)
kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)
"""
def __init__(
self,
channels: int,
compress: float = 4,
depth: int = 2,
init: float = 1e-4,
norm_type: str = "group_norm",
attn: bool = False,
heads: int = 4,
ndecay: int = 4,
lstm: bool = False,
kernel_size: int = 3,
):
super().__init__()
if kernel_size % 2 == 0:
raise ValueError("Kernel size should not be divisible by 2")
self.channels = channels
self.compress = compress
self.depth = abs(depth)
dilate = depth > 0
norm_fn: tp.Callable[[int], nn.Module]
norm_fn = lambda d: nn.Identity() # noqa
if norm_type == "group_norm":
norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
hidden = int(channels / compress)
act = nn.GELU
self.layers = nn.ModuleList([])
for d in range(self.depth):
dilation = pow(2, d) if dilate else 1
padding = dilation * (kernel_size // 2)
mods = [
nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding),
norm_fn(hidden),
act(),
nn.Conv1d(hidden, 2 * channels, 1),
norm_fn(2 * channels),
nn.GLU(1),
_LayerScale(channels, init),
]
if attn:
mods.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
if lstm:
mods.insert(3, _BLSTM(hidden, layers=2, skip=True))
layer = nn.Sequential(*mods)
self.layers.append(layer)
def forward(self, x):
r"""DConv forward call
Args:
x (torch.Tensor): input tensor for convolution
Returns:
Tensor
Output after being run through layers.
"""
for layer in self.layers:
x = x + layer(x)
return x
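def _example_dconv():
    # Illustrative sketch, not part of the original torchaudio module: the residual
    # branch compresses channels by `compress`, applies dilated convolutions, and
    # returns a tensor of the input shape. Values are assumptions for demonstration.
    dconv = _DConv(channels=48, compress=4, depth=2)
    x = torch.randn(1, 48, 1000)
    return dconv(x).shape  # (1, 48, 1000)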
class _BLSTM(torch.nn.Module):
r"""
BiLSTM with same hidden units as input dim.
If `max_steps` is not None, the input will be split into overlapping
chunks and the LSTM applied separately to each chunk.
Args:
dim (int): dimensions at LSTM layer.
layers (int, optional): number of LSTM layers. (default: 1)
skip (bool, optional): whether to add the input back to the output as a residual connection. (default: ``False``)
"""
def __init__(self, dim, layers: int = 1, skip: bool = False):
super().__init__()
self.max_steps = 200
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
self.linear = nn.Linear(2 * dim, dim)
self.skip = skip
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""BLSTM forward call
Args:
x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`
Returns:
Tensor
Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
"""
B, C, T = x.shape
y = x
framed = False
width = 0
stride = 0
nframes = 0
if self.max_steps is not None and T > self.max_steps:
width = self.max_steps
stride = width // 2
frames = _unfold(x, width, stride)
nframes = frames.shape[2]
framed = True
x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
x = x.permute(2, 0, 1)
x = self.lstm(x)[0]
x = self.linear(x)
x = x.permute(1, 2, 0)
if framed:
out = []
frames = x.reshape(B, -1, C, width)
limit = stride // 2
for k in range(nframes):
if k == 0:
out.append(frames[:, k, :, :-limit])
elif k == nframes - 1:
out.append(frames[:, k, :, limit:])
else:
out.append(frames[:, k, :, limit:-limit])
out = torch.cat(out, -1)
out = out[..., :T]
x = out
if self.skip:
x = x + y
return x
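def _example_blstm_chunking():
    # Illustrative sketch, not part of the original torchaudio module: sequences
    # longer than `max_steps` (200) are processed as overlapping frames and
    # stitched back together, so the output keeps the input length. Values are
    # assumptions for demonstration.
    blstm = _BLSTM(dim=64, layers=2, skip=True)
    x = torch.randn(1, 64, 1000)  # longer than max_steps, so chunking is used
    return blstm(x).shape  # (1, 64, 1000)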
class _LocalState(nn.Module):
"""Local state allows to have attention based only on data (no positional embedding),
but while setting a constraint on the time window (e.g. decaying penalty term).
Also a failed experiments with trying to provide some frequency based attention.
"""
def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
r"""
Args:
channels (int): Size of Conv1d layers.
heads (int, optional): (default: 4)
ndecay (int, optional): (default: 4)
"""
super(_LocalState, self).__init__()
if channels % heads != 0:
raise ValueError("Channels must be divisible by heads.")
self.heads = heads
self.ndecay = ndecay
self.content = nn.Conv1d(channels, channels, 1)
self.query = nn.Conv1d(channels, channels, 1)
self.key = nn.Conv1d(channels, channels, 1)
self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
if ndecay:
# Initialize decay close to zero (there is a sigmoid), for maximum initial window.
self.query_decay.weight.data *= 0.01
if self.query_decay.bias is None:
raise ValueError("bias must not be None.")
self.query_decay.bias.data[:] = -2
self.proj = nn.Conv1d(channels + heads * 0, channels, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""LocalState forward call
Args:
x (torch.Tensor): input tensor for LocalState
Returns:
Tensor
Output after being run through LocalState layer.
"""
B, C, T = x.shape
heads = self.heads
indexes = torch.arange(T, device=x.device, dtype=x.dtype)
# left indices are keys, right indices are queries
delta = indexes[:, None] - indexes[None, :]
queries = self.query(x).view(B, heads, -1, T)
keys = self.key(x).view(B, heads, -1, T)
# t are keys, s are queries
dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
dots /= math.sqrt(keys.shape[2])
if self.ndecay:
decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
decay_q = self.query_decay(x).view(B, heads, -1, T)
decay_q = torch.sigmoid(decay_q) / 2
decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay)
dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
# Kill self reference.
dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
weights = torch.softmax(dots, dim=2)
content = self.content(x).view(B, heads, -1, T)
result = torch.einsum("bhts,bhct->bhcs", weights, content)
result = result.reshape(B, -1, T)
return x + self.proj(result)
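def _example_local_attention():
    # Illustrative sketch, not part of the original torchaudio module: attention over
    # time with a learned decaying penalty rather than positional embeddings; the
    # residual is added inside the module, so the shape is preserved. Values are
    # assumptions for demonstration.
    attn = _LocalState(channels=64, heads=4, ndecay=4)
    x = torch.randn(1, 64, 500)
    return attn(x).shape  # (1, 64, 500)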
class _LayerScale(nn.Module):
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
This rescales the residual outputs diagonally, with weights initialized close to 0 and then learned.
"""
def __init__(self, channels: int, init: float = 0):
r"""
Args:
channels (int): Size of rescaling
init (float, optional): Initial value for the scale (default: 0)
"""
super().__init__()
self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
self.scale.data[:] = init
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""LayerScale forward call
Args:
x (torch.Tensor): input tensor for LayerScale
Returns:
Tensor
Output after rescaling tensor.
"""
return self.scale[:, None] * x
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
with K the kernel size, by extracting frames with the given stride.
This will pad the input so that `F = ceil(T / stride)`.
see https://github.com/pytorch/pytorch/issues/60466
"""
shape = list(a.shape[:-1])
length = int(a.shape[-1])
n_frames = math.ceil(length / stride)
tgt_length = (n_frames - 1) * stride + kernel_size
a = F.pad(input=a, pad=[0, tgt_length - length])
strides = [a.stride(dim) for dim in range(a.dim())]
if strides[-1] != 1:
raise ValueError("Data should be contiguous.")
strides = strides[:-1] + [stride, 1]
shape.append(n_frames)
shape.append(kernel_size)
return a.as_strided(shape, strides)
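def _example_unfold():
    # Illustrative sketch, not part of the original torchaudio module: frames a signal
    # into overlapping windows without copying, via as_strided. Values are assumptions
    # for demonstration.
    a = torch.arange(10.0).view(1, 10)
    frames = _unfold(a, kernel_size=4, stride=2)  # (1, 5, 4), zero-padded at the end
    return frames  # frames[0, 0] is [0, 1, 2, 3], frames[0, 1] is [2, 3, 4, 5]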
def _rescale_module(module):
r"""
Rescales the initial weight scale for all convolution layers within the module.
"""
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
std = sub.weight.std().detach()
scale = (std / 0.1) ** 0.5
sub.weight.data /= scale
if sub.bias is not None:
sub.bias.data /= scale
def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
other = list(x.shape[:-1])
length = int(x.shape[-1])
x = x.reshape(-1, length)
z = torch.stft(
x,
n_fft * (1 + pad),
hop_length,
window=torch.hann_window(n_fft).to(x),
win_length=n_fft,
normalized=True,
center=True,
return_complex=True,
pad_mode="reflect",
)
_, freqs, frame = z.shape
other.extend([freqs, frame])
return z.view(other)
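def _example_spectro_shape():
    # Illustrative sketch, not part of the original torchaudio module: with center
    # padding, `_spectro` yields n_fft // 2 + 1 bins and T // hop_length + 1 frames
    # when T is divisible by hop_length. Values are assumptions for demonstration.
    x = torch.randn(2, 2, 16384)
    z = _spectro(x, n_fft=1024, hop_length=256)
    return z.shape  # (2, 2, 513, 65)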
def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
other = list(z.shape[:-2])
freqs = int(z.shape[-2])
frames = int(z.shape[-1])
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
x = torch.istft(
z,
n_fft,
hop_length,
window=torch.hann_window(win_length).to(z.real),
win_length=win_length,
normalized=True,
length=length,
center=True,
)
_, length = x.shape
other.append(length)
return x.view(other)
def hdemucs_low(sources: List[str]) -> HDemucs:
"""Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.
Args:
sources (List[str]): See :py:func:`HDemucs`.
Returns:
HDemucs:
HDemucs model.
"""
return HDemucs(sources=sources, nfft=1024, depth=5)
def hdemucs_medium(sources: List[str]) -> HDemucs:
r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.
.. note::
Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
not compatible with the original implementation in https://github.com/facebookresearch/demucs
Args:
sources (List[str]): See :py:func:`HDemucs`.
Returns:
HDemucs:
HDemucs model.
"""
return HDemucs(sources=sources, nfft=2048, depth=6)
def hdemucs_high(sources: List[str]) -> HDemucs:
r"""Builds medium nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.
Args:
sources (List[str]): See :py:func:`HDemucs`.
Returns:
HDemucs:
HDemucs model.
"""
return HDemucs(sources=sources, nfft=4096, depth=6)
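def _example_factory_choice():
    # Illustrative sketch, not part of the original torchaudio module: pick the factory
    # matching the sample rate of the material (low ~8 kHz, medium 16-32 kHz,
    # high 44.1-48 kHz). Values are assumptions chosen for demonstration.
    model = hdemucs_high(sources=["drums", "bass", "other", "vocals"])
    mixture = torch.randn(1, 2, 44100 * 3)
    with torch.no_grad():
        return model(mixture).shape  # (1, 4, 2, 44100 * 3)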