Upload 32 files
- models/mossformer2_sr/__init__.py +0 -0
- models/mossformer2_sr/__pycache__/__init__.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/__init__.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/conv_module.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/conv_module.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/env.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/fsmn.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/fsmn.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/generator.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/layer_norm.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/layer_norm.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_block.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_block.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer2_sr_wrapper.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/mossformer_block.cpython-38.pyc +0 -0
- models/mossformer2_sr/__pycache__/snake.cpython-312.pyc +0 -0
- models/mossformer2_sr/__pycache__/utils.cpython-312.pyc +0 -0
- models/mossformer2_sr/conv_module.py +388 -0
- models/mossformer2_sr/env.py +15 -0
- models/mossformer2_sr/fsmn.py +214 -0
- models/mossformer2_sr/generator.py +448 -0
- models/mossformer2_sr/layer_norm.py +126 -0
- models/mossformer2_sr/mossformer2.py +711 -0
- models/mossformer2_sr/mossformer2_block.py +735 -0
- models/mossformer2_sr/mossformer2_sr_wrapper.py +52 -0
- models/mossformer2_sr/snake.py +33 -0
- models/mossformer2_sr/utils.py +37 -0
models/mossformer2_sr/__init__.py
ADDED
File without changes
models/mossformer2_sr/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (198 Bytes)
models/mossformer2_sr/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (193 Bytes)
models/mossformer2_sr/__pycache__/conv_module.cpython-312.pyc
ADDED
Binary file (20.4 kB)
models/mossformer2_sr/__pycache__/conv_module.cpython-38.pyc
ADDED
Binary file (13.9 kB)
models/mossformer2_sr/__pycache__/env.cpython-312.pyc
ADDED
Binary file (1.19 kB)
models/mossformer2_sr/__pycache__/fsmn.cpython-312.pyc
ADDED
Binary file (12.4 kB)
models/mossformer2_sr/__pycache__/fsmn.cpython-38.pyc
ADDED
Binary file (8.51 kB)
models/mossformer2_sr/__pycache__/generator.cpython-312.pyc
ADDED
Binary file (24.6 kB)
models/mossformer2_sr/__pycache__/layer_norm.cpython-312.pyc
ADDED
Binary file (6.59 kB)
models/mossformer2_sr/__pycache__/layer_norm.cpython-38.pyc
ADDED
Binary file (4.24 kB)
models/mossformer2_sr/__pycache__/mossformer.cpython-38.pyc
ADDED
Binary file (16.3 kB)
models/mossformer2_sr/__pycache__/mossformer2.cpython-312.pyc
ADDED
Binary file (22.8 kB)
models/mossformer2_sr/__pycache__/mossformer2.cpython-38.pyc
ADDED
Binary file (15.9 kB)
models/mossformer2_sr/__pycache__/mossformer2_block.cpython-312.pyc
ADDED
Binary file (30.8 kB)
models/mossformer2_sr/__pycache__/mossformer2_block.cpython-38.pyc
ADDED
Binary file (23.5 kB)
models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-312.pyc
ADDED
Binary file (4.05 kB)
models/mossformer2_sr/__pycache__/mossformer2_se_wrapper.cpython-38.pyc
ADDED
Binary file (3.6 kB)
models/mossformer2_sr/__pycache__/mossformer2_sr_wrapper.cpython-312.pyc
ADDED
Binary file (2.36 kB)
models/mossformer2_sr/__pycache__/mossformer_block.cpython-38.pyc
ADDED
Binary file (21.2 kB)
models/mossformer2_sr/__pycache__/snake.cpython-312.pyc
ADDED
Binary file (2.24 kB)
models/mossformer2_sr/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (2.46 kB)
models/mossformer2_sr/conv_module.py
ADDED
@@ -0,0 +1,388 @@
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.init as init
import torch.nn.functional as F

EPS = 1e-8

class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Arguments
    ---------
    dim : (int or list or torch.Size)
        Input shape from an expected input of size.
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x = N x C x K x S or N x C x L
        # N x 1 x 1
        # cln: mean,var N x 1 x K x S
        # gln: mean,var N x 1 x 1
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : True
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor size [N, C, K, S] or [N, C, L]
        """
        # x: N x C x K x S or N x C x L
        # N x K x S x C
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1).contiguous()
            # N x K x S x C == only channel norm
            x = super().forward(x)
            # N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            x = torch.transpose(x, 1, 2)
            # N x L x C == only channel norm
            x = super().forward(x)
            # N x C x L
            x = torch.transpose(x, 1, 2)
        return x


def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type.
    """

    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    if norm == "cln":
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == "ln":
        return nn.GroupNorm(1, dim, eps=1e-8)
    else:
        return nn.BatchNorm1d(dim)

class Swish(nn.Module):
    """
    Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
    to a variety of challenging domains such as Image classification and Machine translation.
    """
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs * inputs.sigmoid()


class GLU(nn.Module):
    """
    The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
    in the paper “Language Modeling with Gated Convolutional Networks”
    """
    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()

class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)

class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear
    Weight initialize by xavier initialization and bias initialize to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in literature as depthwise convolution.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    """
    When kernel size == 1 conv1d, this operation is termed in literature as pointwise convolution.
    This operation often used to match dimensions.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produces by pointwise 1-D convolution.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConvModule(nn.Module):
    """
    Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
    to aid training deep models.
    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
        dropout_p (float, optional): probability of dropout
    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, time, dim): Tensor produces by conformer convolution module.
    """
    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs + self.sequential(inputs).transpose(1, 2)

class ConvModule_Dilated(nn.Module):
    """
    Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
    to aid training deep models.
    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
        dropout_p (float, optional): probability of dropout
    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, time, dim): Tensor produces by conformer convolution module.
    """
    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule_Dilated, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs + self.sequential(inputs).transpose(1, 2)

class DilatedDenseNet(nn.Module):
    def __init__(self, depth=4, lorder=20, in_channels=64):
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
        self.twidth = lorder*2-1
        self.kernel_size = (self.twidth, 1)
        for i in range(self.depth):
            dil = 2 ** i
            pad_length = lorder + (dil - 1) * (lorder - 1) - 1
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels*(i+1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1), groups=self.in_channels, bias=False))
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x_per = x.permute(0, 3, 2, 1)
        skip = x_per
        for i in range(self.depth):
            out = getattr(self, 'pad{}'.format(i + 1))(skip)
            out = getattr(self, 'conv{}'.format(i + 1))(out)
            out = getattr(self, 'norm{}'.format(i + 1))(out)
            out = getattr(self, 'prelu{}'.format(i + 1))(out)
            skip = torch.cat([out, skip], dim=1)
        out1 = out.permute(0, 3, 2, 1)
        return out1.squeeze(1)

class FFConvM_Dilated(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass = nn.LayerNorm,
        dropout = 0.1
    ):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            DilatedDenseNet(depth=2, lorder=17, in_channels=dim_out),
            nn.Dropout(dropout)
        )
    def forward(
        self,
        x,
    ):
        output = self.mdl(x)
        return output
models/mossformer2_sr/env.py
ADDED
@@ -0,0 +1,15 @@
import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
models/mossformer2_sr/fsmn.py
ADDED
@@ -0,0 +1,214 @@
import torch.nn as nn
import torch.nn.functional as F
import torch as th
from torch.nn.parameter import Parameter
import numpy as np
import os

class UniDeepFsmn(nn.Module):
    """
    UniDeepFsmn is a neural network module that implements a single-deep feedforward sequence memory network (FSMN).

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        lorder (int): Length of the order for the convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to hidden size.
        project (nn.Linear): Linear layer to project hidden features to output dimensions.
        conv1 (nn.Conv2d): Convolutional layer for processing the output in a grouped manner.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize the layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim, bias=False)  # Convolution layer

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: The output tensor of the same shape as input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad for causal convolution
        out = x_per + self.conv1(y)  # Add original input to convolution output
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions
        return input + out1.squeeze()  # Return enhanced input


class UniDeepFsmn_dual(nn.Module):
    """
    UniDeepFsmn_dual is a neural network module that implements a dual-deep feedforward sequence memory network (FSMN).

    This class extends the UniDeepFsmn by adding a second convolution layer for richer feature extraction.

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        lorder (int): Length of the order for the convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to hidden size.
        project (nn.Linear): Linear layer to project hidden features to output dimensions.
        conv1 (nn.Conv2d): First convolutional layer for processing the output.
        conv2 (nn.Conv2d): Second convolutional layer for further processing the features.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn_dual, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize the layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim, bias=False)  # First convolution layer
        self.conv2 = nn.Conv2d(output_dim, output_dim, [lorder + lorder - 1, 1], [1, 1], groups=output_dim // 4, bias=False)  # Second convolution layer

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn_dual model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: The output tensor of the same shape as input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad for causal convolution
        conv1_out = x_per + self.conv1(y)  # Add original input to first convolution output
        z = F.pad(conv1_out, [0, 0, self.lorder - 1, self.lorder - 1])  # Pad for second convolution
        out = conv1_out + self.conv2(z)  # Add output of second convolution
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions
        return input + out1.squeeze()  # Return enhanced input


class DilatedDenseNet(nn.Module):
    """
    DilatedDenseNet implements a dense network structure with dilated convolutions.

    This architecture enables wider receptive fields while maintaining a lower number of parameters.
    It consists of multiple convolutional layers with dilation rates that increase at each layer.

    Attributes:
        depth (int): Number of convolutional layers in the network.
        in_channels (int): Number of input channels for the first layer.
        pad (nn.ConstantPad2d): Padding layer to maintain dimensions.
        twidth (int): Width of the kernel used in convolution.
        kernel_size (tuple): Kernel size for convolution operations.
    """

    def __init__(self, depth=4, lorder=20, in_channels=64):
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)  # Padding for the input
        self.twidth = lorder * 2 - 1  # Width of the kernel
        self.kernel_size = (self.twidth, 1)  # Kernel size for convolutions

        # Initialize layers dynamically based on depth
        for i in range(self.depth):
            dil = 2 ** i  # Calculate dilation rate
            pad_length = lorder + (dil - 1) * (lorder - 1) - 1  # Calculate padding length
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))  # Padding for dilation
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1), groups=self.in_channels, bias=False))  # Convolution layer with dilation
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))  # Normalization layer
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))  # Activation layer

    def forward(self, x):
        """
        Forward pass for the DilatedDenseNet model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after applying dense layers.
        """
        skip = x  # Initialize skip connection
        for i in range(self.depth):
            out = getattr(self, 'pad{}'.format(i + 1))(skip)  # Apply padding
            out = getattr(self, 'conv{}'.format(i + 1))(out)  # Apply convolution
            out = getattr(self, 'norm{}'.format(i + 1))(out)  # Apply normalization
            out = getattr(self, 'prelu{}'.format(i + 1))(out)  # Apply PReLU activation
            skip = th.cat([out, skip], dim=1)  # Concatenate the output with the skip connection
        return out  # Return the final output

class UniDeepFsmn_dilated(nn.Module):
    """
    UniDeepFsmn_dilated combines the UniDeepFsmn architecture with a dilated dense network
    to enhance feature extraction while maintaining efficient computation.

    Attributes:
        input_dim (int): Dimension of the input features.
        output_dim (int): Dimension of the output features.
        depth (int): Depth of the dilated dense network.
        lorder (int): Length of the order for the convolution layers.
        hidden_size (int): Number of hidden units in the linear layer.
        linear (nn.Linear): Linear layer to project input features to hidden size.
        project (nn.Linear): Linear layer to project hidden features to output dimensions.
        conv (DilatedDenseNet): Instance of the DilatedDenseNet for feature extraction.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None, depth=2):
        super(UniDeepFsmn_dilated, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.depth = depth
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size

        # Initialize layers
        self.linear = nn.Linear(input_dim, hidden_size)  # Linear transformation to hidden size
        self.project = nn.Linear(hidden_size, output_dim, bias=False)  # Project hidden size to output dimension
        self.conv = DilatedDenseNet(depth=self.depth, lorder=lorder, in_channels=output_dim)  # Dilated dense network for feature extraction

    def forward(self, input):
        """
        Forward pass for the UniDeepFsmn_dilated model.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: The output tensor of the same shape as input, enhanced by the network.
        """
        f1 = F.relu(self.linear(input))  # Apply linear layer followed by ReLU activation
        p1 = self.project(f1)  # Project to output dimension
        x = th.unsqueeze(p1, 1)  # Add a dimension for compatibility with Conv2d
        x_per = x.permute(0, 3, 2, 1)  # Permute dimensions for convolution
        out = self.conv(x_per)  # Pass through the dilated dense network
        out1 = out.permute(0, 3, 2, 1)  # Permute back to original dimensions

        return input + out1.squeeze()  # Return enhanced input
models/mossformer2_sr/generator.py
ADDED
@@ -0,0 +1,448 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from models.mossformer2_sr.utils import init_weights, get_padding
from models.mossformer2_sr.mossformer2 import MossFormer_MaskNet
from models.mossformer2_sr.snake import Snake1d
from typing import Optional, List, Union, Dict, Tuple
from models.mossformer2_sr.env import AttrDict
import typing
from torchaudio.transforms import Spectrogram, Resample

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
            #Snake1d(channels)
        ])
        self.convs1.apply(init_weights)
        self.convs1_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels),
            Snake1d(channels)
        ])
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
            #Snake1d(channels)
        ])
        self.convs2.apply(init_weights)
        #self.convs2_activate = Snake1d(channels)
        self.convs2_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels),
            Snake1d(channels)
        ])

    def forward(self, x):
        for c1, c2, act1, act2 in zip(self.convs1, self.convs2, self.convs1_activates, self.convs2_activates):
            #xt = F.leaky_relu(x, LRELU_SLOPE)
            #print(f'xt: {xt.shape}')
            xt = act1(x)
            xt = c1(xt)
            #xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = act2(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
            #Snake1d(channels)
        ])
        self.convs.apply(init_weights)
        #self.convs_activate = Snake1d(channels)
        self.convs_activates = nn.ModuleList([
            Snake1d(channels),
            Snake1d(channels)
        ])
    def forward(self, x):
        for c, act in zip(self.convs, self.convs_activates):
            #xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = act(x)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        self.snakes = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.snakes.append(Snake1d(h.upsample_initial_channel//(2**i)))
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.snake_post = Snake1d(ch)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            #x = F.leaky_relu(x, LRELU_SLOPE)
            #print(f'x {i}: {x.shape}')
            x = self.snakes[i](x)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        #x = F.leaky_relu(x)
        x = self.snake_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        #print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0: # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i-1](y)
                y_hat = self.meanpools[i-1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorB(nn.Module):
    def __init__(
        self,
        window_length: int,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = (
            (0.0, 0.1),
            (0.1, 0.25),
            (0.25, 0.5),
            (0.5, 0.75),
            (0.75, 1.0),
        ),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        self.spec_fn = Spectrogram(
            n_fft=window_length,
            hop_length=int(window_length * hop_factor),
            win_length=window_length,
            power=None,
        )
        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
                ),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        self.conv_post = weight_norm(
            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
        )
    def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x_bands = self.spectrogram(x.squeeze(1))
        fmap = []
        x = []

        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return x, fmap

# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiBandDiscriminator(nn.Module):
    def __init__(
        self,
        h,
    ):
        """
        Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
        and the modified code adapted from https://github.com/gemelo-ai/vocos.
        """
        super().__init__()
        # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
        self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
        self.discriminators = nn.ModuleList(
            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
        )

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[List[torch.Tensor]],
        List[List[torch.Tensor]],
    ]:

        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses

class Mossformer(nn.Module):

    def __init__(self):
        super(Mossformer, self).__init__()
        self.mossformer = MossFormer_MaskNet(in_channels=80, out_channels=512, out_channels_final=80)

    def forward(self, input):
        out = self.mossformer(input)
        return out
models/mossformer2_sr/layer_norm.py
ADDED
|
@@ -0,0 +1,126 @@
|
| 1 |
+
#!/usr/bin/env python -u
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
|
| 5 |
+
|
| 6 |
+
from __future__ import absolute_import
|
| 7 |
+
from __future__ import division
|
| 8 |
+
from __future__ import print_function
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CLayerNorm(nn.LayerNorm):
|
| 15 |
+
"""Channel-wise layer normalization."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, *args, **kwargs):
|
| 18 |
+
super(CLayerNorm, self).__init__(*args, **kwargs)
|
| 19 |
+
|
| 20 |
+
def forward(self, sample):
|
| 21 |
+
"""Forward function.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
sample: [batch_size, channels, length]
|
| 25 |
+
"""
|
| 26 |
+
if sample.dim() != 3:
|
| 27 |
+
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
|
| 28 |
+
self.__class__.__name__))
|
| 29 |
+
# [N, C, T] -> [N, T, C]
|
| 30 |
+
sample = torch.transpose(sample, 1, 2)
|
| 31 |
+
# LayerNorm
|
| 32 |
+
sample = super().forward(sample)
|
| 33 |
+
# [N, T, C] -> [N, C, T]
|
| 34 |
+
sample = torch.transpose(sample, 1, 2)
|
| 35 |
+
return sample
|
| 36 |
+
|
| 37 |
+
class ILayerNorm(nn.InstanceNorm1d):
|
| 38 |
+
"""Channel-wise layer normalization."""
|
| 39 |
+
|
| 40 |
+
def __init__(self, *args, **kwargs):
|
| 41 |
+
super(ILayerNorm, self).__init__(*args, **kwargs)
|
| 42 |
+
|
| 43 |
+
def forward(self, sample):
|
| 44 |
+
"""Forward function.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
sample: [batch_size, channels, length]
|
| 48 |
+
"""
|
| 49 |
+
if sample.dim() != 3:
|
| 50 |
+
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
|
| 51 |
+
self.__class__.__name__))
|
| 52 |
+
# [N, C, T] -> [N, T, C]
|
| 53 |
+
sample = torch.transpose(sample, 1, 2)
|
| 54 |
+
# LayerNorm
|
| 55 |
+
sample = super().forward(sample)
|
| 56 |
+
# [N, T, C] -> [N, C, T]
|
| 57 |
+
sample = torch.transpose(sample, 1, 2)
|
| 58 |
+
return sample
|
| 59 |
+
|
| 60 |
+
class GLayerNorm(nn.Module):
|
| 61 |
+
"""Global Layer Normalization for TasNet."""
|
| 62 |
+
|
| 63 |
+
def __init__(self, channels, eps=1e-5):
|
| 64 |
+
super(GLayerNorm, self).__init__()
|
| 65 |
+
self.eps = eps
|
| 66 |
+
self.norm_dim = channels
|
| 67 |
+
self.gamma = nn.Parameter(torch.Tensor(channels))
|
| 68 |
+
self.beta = nn.Parameter(torch.Tensor(channels))
|
| 69 |
+
self.reset_parameters()
|
| 70 |
+
|
| 71 |
+
def reset_parameters(self):
|
| 72 |
+
nn.init.ones_(self.gamma)
|
| 73 |
+
nn.init.zeros_(self.beta)
|
| 74 |
+
|
| 75 |
+
def forward(self, sample):
|
| 76 |
+
"""Forward function.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
sample: [batch_size, channels, length]
|
| 80 |
+
"""
|
| 81 |
+
if sample.dim() != 3:
|
| 82 |
+
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
|
| 83 |
+
self.__class__.__name__))
|
| 84 |
+
# [N, C, T] -> [N, T, C]
|
| 85 |
+
sample = torch.transpose(sample, 1, 2)
|
| 86 |
+
# Mean and variance [N, 1, 1]
|
| 87 |
+
mean = torch.mean(sample, (1, 2), keepdim=True)
|
| 88 |
+
var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
|
| 89 |
+
sample = (sample - mean) / torch.sqrt(var + self.eps) * \
|
| 90 |
+
self.gamma + self.beta
|
| 91 |
+
# [N, T, C] -> [N, C, T]
|
| 92 |
+
sample = torch.transpose(sample, 1, 2)
|
| 93 |
+
return sample
|
| 94 |
+
|
| 95 |
+
class _LayerNorm(nn.Module):
|
| 96 |
+
"""Layer Normalization base class."""
|
| 97 |
+
|
| 98 |
+
def __init__(self, channel_size):
|
| 99 |
+
super(_LayerNorm, self).__init__()
|
| 100 |
+
self.channel_size = channel_size
|
| 101 |
+
self.gamma = nn.Parameter(torch.ones(channel_size),
|
| 102 |
+
requires_grad=True)
|
| 103 |
+
self.beta = nn.Parameter(torch.zeros(channel_size),
|
| 104 |
+
requires_grad=True)
|
| 105 |
+
|
| 106 |
+
def apply_gain_and_bias(self, normed_x):
|
| 107 |
+
""" Assumes input of size `[batch, chanel, *]`. """
|
| 108 |
+
return (self.gamma * normed_x.transpose(1, -1) +
|
| 109 |
+
self.beta).transpose(1, -1)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class GlobLayerNorm(_LayerNorm):
|
| 113 |
+
"""Global Layer Normalization (globLN)."""
|
| 114 |
+
|
| 115 |
+
def forward(self, x):
|
| 116 |
+
""" Applies forward pass.
|
| 117 |
+
Works for any input size > 2D.
|
| 118 |
+
Args:
|
| 119 |
+
x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
|
| 120 |
+
Returns:
|
| 121 |
+
:class:`torch.Tensor`: gLN_x `[batch, chan, *]`
|
| 122 |
+
"""
|
| 123 |
+
dims = list(range(1, len(x.shape)))
|
| 124 |
+
mean = x.mean(dim=dims, keepdim=True)
|
| 125 |
+
var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
|
| 126 |
+
return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
|
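# --- Illustrative usage (added note, not part of the original upload) ---
# A minimal sketch: the norm layers above expect channel-first input
# [batch, channels, length] and preserve its shape.
#
#   x = torch.randn(4, 64, 100)
#   for Norm in (CLayerNorm, GLayerNorm, GlobLayerNorm):
#       assert Norm(64)(x).shape == x.shape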
models/mossformer2_sr/mossformer2.py
ADDED
|
@@ -0,0 +1,711 @@
|
| 1 |
+
"""
|
| 2 |
+
modified from https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/dual_path.py
|
| 3 |
+
Author: Shengkui Zhao
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import copy
|
| 11 |
+
from models.mossformer2_sr.mossformer2_block import ScaledSinuEmbedding, MossformerBlock_GFSMN, MossformerBlock
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
EPS = 1e-8
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GlobalLayerNorm(nn.Module):
|
| 18 |
+
"""Calculate Global Layer Normalization.
|
| 19 |
+
|
| 20 |
+
Arguments
|
| 21 |
+
---------
|
| 22 |
+
dim : (int or list or torch.Size)
|
| 23 |
+
Input shape from an expected input of size.
|
| 24 |
+
eps : float
|
| 25 |
+
A value added to the denominator for numerical stability.
|
| 26 |
+
elementwise_affine : bool
|
| 27 |
+
A boolean value that when set to True,
|
| 28 |
+
this module has learnable per-element affine parameters
|
| 29 |
+
initialized to ones (for weights) and zeros (for biases).
|
| 30 |
+
|
| 31 |
+
Example
|
| 32 |
+
-------
|
| 33 |
+
>>> x = torch.randn(5, 10, 20)
|
| 34 |
+
>>> GLN = GlobalLayerNorm(10, 3)
|
| 35 |
+
>>> x_norm = GLN(x)
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
|
| 39 |
+
super(GlobalLayerNorm, self).__init__()
|
| 40 |
+
self.dim = dim
|
| 41 |
+
self.eps = eps
|
| 42 |
+
self.elementwise_affine = elementwise_affine
|
| 43 |
+
|
| 44 |
+
if self.elementwise_affine:
|
| 45 |
+
if shape == 3:
|
| 46 |
+
self.weight = nn.Parameter(torch.ones(self.dim, 1))
|
| 47 |
+
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
|
| 48 |
+
if shape == 4:
|
| 49 |
+
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
|
| 50 |
+
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
|
| 51 |
+
else:
|
| 52 |
+
self.register_parameter("weight", None)
|
| 53 |
+
self.register_parameter("bias", None)
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
"""Returns the normalized tensor.
|
| 57 |
+
|
| 58 |
+
Arguments
|
| 59 |
+
---------
|
| 60 |
+
x : torch.Tensor
|
| 61 |
+
Tensor of size [N, C, K, S] or [N, C, L].
|
| 62 |
+
"""
|
| 63 |
+
# x = N x C x K x S or N x C x L
|
| 64 |
+
# N x 1 x 1
|
| 65 |
+
# cln: mean,var N x 1 x K x S
|
| 66 |
+
# gln: mean,var N x 1 x 1
|
| 67 |
+
if x.dim() == 3:
|
| 68 |
+
mean = torch.mean(x, (1, 2), keepdim=True)
|
| 69 |
+
var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
|
| 70 |
+
if self.elementwise_affine:
|
| 71 |
+
x = (
|
| 72 |
+
self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
| 73 |
+
+ self.bias
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
x = (x - mean) / torch.sqrt(var + self.eps)
|
| 77 |
+
|
| 78 |
+
if x.dim() == 4:
|
| 79 |
+
mean = torch.mean(x, (1, 2, 3), keepdim=True)
|
| 80 |
+
var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
|
| 81 |
+
if self.elementwise_affine:
|
| 82 |
+
x = (
|
| 83 |
+
self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
| 84 |
+
+ self.bias
|
| 85 |
+
)
|
| 86 |
+
else:
|
| 87 |
+
x = (x - mean) / torch.sqrt(var + self.eps)
|
| 88 |
+
return x
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class CumulativeLayerNorm(nn.LayerNorm):
|
| 92 |
+
"""Calculate Cumulative Layer Normalization.
|
| 93 |
+
|
| 94 |
+
Arguments
|
| 95 |
+
---------
|
| 96 |
+
dim : int
|
| 97 |
+
Dimension that you want to normalize.
|
| 98 |
+
elementwise_affine : True
|
| 99 |
+
Learnable per-element affine parameters.
|
| 100 |
+
|
| 101 |
+
Example
|
| 102 |
+
-------
|
| 103 |
+
>>> x = torch.randn(5, 10, 20)
|
| 104 |
+
>>> CLN = CumulativeLayerNorm(10)
|
| 105 |
+
>>> x_norm = CLN(x)
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, dim, elementwise_affine=True):
|
| 109 |
+
super(CumulativeLayerNorm, self).__init__(
|
| 110 |
+
dim, elementwise_affine=elementwise_affine, eps=1e-8
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
"""Returns the normalized tensor.
|
| 115 |
+
|
| 116 |
+
Arguments
|
| 117 |
+
---------
|
| 118 |
+
x : torch.Tensor
|
| 119 |
+
Tensor size [N, C, K, S] or [N, C, L]
|
| 120 |
+
"""
|
| 121 |
+
# x: N x C x K x S or N x C x L
|
| 122 |
+
# N x K x S x C
|
| 123 |
+
if x.dim() == 4:
|
| 124 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
| 125 |
+
# N x K x S x C == only channel norm
|
| 126 |
+
x = super().forward(x)
|
| 127 |
+
# N x C x K x S
|
| 128 |
+
x = x.permute(0, 3, 1, 2).contiguous()
|
| 129 |
+
if x.dim() == 3:
|
| 130 |
+
x = torch.transpose(x, 1, 2)
|
| 131 |
+
# N x L x C == only channel norm
|
| 132 |
+
x = super().forward(x)
|
| 133 |
+
# N x C x L
|
| 134 |
+
x = torch.transpose(x, 1, 2)
|
| 135 |
+
return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def select_norm(norm, dim, shape):
|
| 139 |
+
"""Just a wrapper to select the normalization type.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
if norm == "gln":
|
| 143 |
+
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
|
| 144 |
+
if norm == "cln":
|
| 145 |
+
return CumulativeLayerNorm(dim, elementwise_affine=True)
|
| 146 |
+
if norm == "ln":
|
| 147 |
+
return nn.GroupNorm(1, dim, eps=1e-8)
|
| 148 |
+
else:
|
| 149 |
+
return nn.BatchNorm1d(dim)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Encoder(nn.Module):
|
| 153 |
+
"""Convolutional Encoder Layer.
|
| 154 |
+
|
| 155 |
+
Arguments
|
| 156 |
+
---------
|
| 157 |
+
kernel_size : int
|
| 158 |
+
Length of filters.
|
| 159 |
+
in_channels : int
|
| 160 |
+
Number of input channels.
|
| 161 |
+
out_channels : int
|
| 162 |
+
Number of output channels.
|
| 163 |
+
|
| 164 |
+
Example
|
| 165 |
+
-------
|
| 166 |
+
>>> x = torch.randn(2, 1000)
|
| 167 |
+
>>> encoder = Encoder(kernel_size=4, out_channels=64)
|
| 168 |
+
>>> h = encoder(x)
|
| 169 |
+
>>> h.shape
|
| 170 |
+
torch.Size([2, 64, 499])
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
|
| 174 |
+
super(Encoder, self).__init__()
|
| 175 |
+
self.conv1d = nn.Conv1d(
|
| 176 |
+
in_channels=in_channels,
|
| 177 |
+
out_channels=out_channels,
|
| 178 |
+
kernel_size=kernel_size,
|
| 179 |
+
stride=kernel_size // 2,
|
| 180 |
+
groups=1,
|
| 181 |
+
bias=False,
|
| 182 |
+
)
|
| 183 |
+
self.in_channels = in_channels
|
| 184 |
+
|
| 185 |
+
def forward(self, x):
|
| 186 |
+
"""Return the encoded output.
|
| 187 |
+
|
| 188 |
+
Arguments
|
| 189 |
+
---------
|
| 190 |
+
x : torch.Tensor
|
| 191 |
+
Input tensor with dimensionality [B, L].
|
| 192 |
+
Return
|
| 193 |
+
------
|
| 194 |
+
x : torch.Tensor
|
| 195 |
+
Encoded tensor with dimensionality [B, N, T_out].
|
| 196 |
+
|
| 197 |
+
where B = Batchsize
|
| 198 |
+
L = Number of timepoints
|
| 199 |
+
N = Number of filters
|
| 200 |
+
T_out = Number of timepoints at the output of the encoder
|
| 201 |
+
"""
|
| 202 |
+
# B x L -> B x 1 x L
|
| 203 |
+
if self.in_channels == 1:
|
| 204 |
+
x = torch.unsqueeze(x, dim=1)
|
| 205 |
+
# B x 1 x L -> B x N x T_out
|
| 206 |
+
x = self.conv1d(x)
|
| 207 |
+
x = F.relu(x)
|
| 208 |
+
|
| 209 |
+
return x
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class Decoder(nn.ConvTranspose1d):
|
| 213 |
+
"""A decoder layer that consists of ConvTranspose1d.
|
| 214 |
+
|
| 215 |
+
Arguments
|
| 216 |
+
---------
|
| 217 |
+
kernel_size : int
|
| 218 |
+
Length of filters.
|
| 219 |
+
in_channels : int
|
| 220 |
+
Number of input channels.
|
| 221 |
+
out_channels : int
|
| 222 |
+
Number of output channels.
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
Example
|
| 226 |
+
---------
|
| 227 |
+
>>> x = torch.randn(2, 100, 1000)
|
| 228 |
+
>>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
|
| 229 |
+
>>> h = decoder(x)
|
| 230 |
+
>>> h.shape
|
| 231 |
+
torch.Size([2, 1003])
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
def __init__(self, *args, **kwargs):
|
| 235 |
+
super(Decoder, self).__init__(*args, **kwargs)
|
| 236 |
+
|
| 237 |
+
def forward(self, x):
|
| 238 |
+
"""Return the decoded output.
|
| 239 |
+
|
| 240 |
+
Arguments
|
| 241 |
+
---------
|
| 242 |
+
x : torch.Tensor
|
| 243 |
+
Input tensor with dimensionality [B, N, L].
|
| 244 |
+
where, B = Batchsize,
|
| 245 |
+
N = number of filters
|
| 246 |
+
L = time points
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
if x.dim() not in [2, 3]:
|
| 250 |
+
raise RuntimeError(
|
| 251 |
+
"{} accept 3/4D tensor as input".format(self.__name__)
|
| 252 |
+
)
|
| 253 |
+
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
|
| 254 |
+
|
| 255 |
+
if torch.squeeze(x).dim() == 1:
|
| 256 |
+
x = torch.squeeze(x, dim=1)
|
| 257 |
+
else:
|
| 258 |
+
x = torch.squeeze(x)
|
| 259 |
+
return x
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
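# --- Illustrative usage (added note, not part of the original upload) ---
# A minimal sketch of the Encoder/Decoder pair above; the values mirror the
# docstring examples and are only meant to show the expected shapes.
#
#   enc = Encoder(kernel_size=4, out_channels=64)                        # stride = 2
#   dec = Decoder(in_channels=64, out_channels=1, kernel_size=4, stride=2, bias=False)
#   wav = torch.randn(2, 1000)      # [B, L]
#   h = enc(wav)                    # [B, 64, T_out]
#   rec = dec(h)                    # [B, ~L] reconstructed waveform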
class IdentityBlock:
|
| 263 |
+
"""This block is used when we want to have identity transformation within the Dual_path block.
|
| 264 |
+
|
| 265 |
+
Example
|
| 266 |
+
-------
|
| 267 |
+
>>> x = torch.randn(10, 100)
|
| 268 |
+
>>> IB = IdentityBlock()
|
| 269 |
+
>>> xhat = IB(x)
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
def __init__(self, **kwargs):
|
| 273 |
+
pass
|
| 274 |
+
|
| 275 |
+
def __call__(self, x):
|
| 276 |
+
return x
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class MossFormerM(nn.Module):
|
| 280 |
+
"""This class implements the transformer encoder.
|
| 281 |
+
|
| 282 |
+
Arguments
|
| 283 |
+
---------
|
| 284 |
+
num_blocks : int
|
| 285 |
+
Number of mossformer blocks to include.
|
| 286 |
+
d_model : int
|
| 287 |
+
The dimension of the input embedding.
|
| 288 |
+
attn_dropout : float
|
| 289 |
+
Dropout for the self-attention (Optional).
|
| 290 |
+
group_size: int
|
| 291 |
+
the chunk size
|
| 292 |
+
query_key_dim: int
|
| 293 |
+
the attention vector dimension
|
| 294 |
+
expansion_factor: int
|
| 295 |
+
the expansion factor for the linear projection in conv module
|
| 296 |
+
causal: bool
|
| 297 |
+
true for causal / false for non causal
|
| 298 |
+
|
| 299 |
+
Example
|
| 300 |
+
-------
|
| 301 |
+
>>> import torch
|
| 302 |
+
>>> x = torch.rand((8, 60, 512))
|
| 303 |
+
>>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
|
| 304 |
+
>>> output, _ = net(x)
|
| 305 |
+
>>> output.shape
|
| 306 |
+
torch.Size([8, 60, 512])
|
| 307 |
+
"""
|
| 308 |
+
def __init__(
|
| 309 |
+
self,
|
| 310 |
+
num_blocks,
|
| 311 |
+
d_model=None,
|
| 312 |
+
causal=False,
|
| 313 |
+
group_size = 256,
|
| 314 |
+
query_key_dim = 128,
|
| 315 |
+
expansion_factor = 4.,
|
| 316 |
+
attn_dropout = 0.1
|
| 317 |
+
):
|
| 318 |
+
super().__init__()
|
| 319 |
+
|
| 320 |
+
self.mossformerM = MossformerBlock_GFSMN(
|
| 321 |
+
dim=d_model,
|
| 322 |
+
depth=num_blocks,
|
| 323 |
+
group_size=group_size,
|
| 324 |
+
query_key_dim=query_key_dim,
|
| 325 |
+
expansion_factor=expansion_factor,
|
| 326 |
+
causal=causal,
|
| 327 |
+
attn_dropout=attn_dropout
|
| 328 |
+
)
|
| 329 |
+
self.norm = nn.LayerNorm(d_model, eps=1e-6)
|
| 330 |
+
def forward(
|
| 331 |
+
self,
|
| 332 |
+
src,
|
| 333 |
+
):
|
| 334 |
+
"""
|
| 335 |
+
Arguments
|
| 336 |
+
----------
|
| 337 |
+
src : torch.Tensor
|
| 338 |
+
Tensor shape [B, L, N],
|
| 339 |
+
where, B = Batchsize,
|
| 340 |
+
L = time points
|
| 341 |
+
N = number of filters
|
| 342 |
+
The sequence to the encoder layer (required).
|
| 343 |
+
src_mask : tensor
|
| 344 |
+
The mask for the src sequence (optional).
|
| 345 |
+
src_key_padding_mask : tensor
|
| 346 |
+
The mask for the src keys per batch (optional).
|
| 347 |
+
"""
|
| 348 |
+
output = self.mossformerM(src)
|
| 349 |
+
output = self.norm(output)
|
| 350 |
+
|
| 351 |
+
return output
|
| 352 |
+
|
| 353 |
+
class MossFormerM2(nn.Module):
|
| 354 |
+
"""This class implements the transformer encoder.
|
| 355 |
+
|
| 356 |
+
Arguments
|
| 357 |
+
---------
|
| 358 |
+
num_blocks : int
|
| 359 |
+
Number of mossformer blocks to include.
|
| 360 |
+
d_model : int
|
| 361 |
+
The dimension of the input embedding.
|
| 362 |
+
attn_dropout : float
|
| 363 |
+
Dropout for the self-attention (Optional).
|
| 364 |
+
group_size: int
|
| 365 |
+
the chunk size
|
| 366 |
+
query_key_dim: int
|
| 367 |
+
the attention vector dimension
|
| 368 |
+
expansion_factor: int
|
| 369 |
+
the expansion factor for the linear projection in conv module
|
| 370 |
+
causal: bool
|
| 371 |
+
true for causal / false for non causal
|
| 372 |
+
|
| 373 |
+
Example
|
| 374 |
+
-------
|
| 375 |
+
>>> import torch
|
| 376 |
+
>>> x = torch.rand((8, 60, 512))
|
| 377 |
+
>>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
|
| 378 |
+
>>> output, _ = net(x)
|
| 379 |
+
>>> output.shape
|
| 380 |
+
torch.Size([8, 60, 512])
|
| 381 |
+
"""
|
| 382 |
+
def __init__(
|
| 383 |
+
self,
|
| 384 |
+
num_blocks,
|
| 385 |
+
d_model=None,
|
| 386 |
+
causal=False,
|
| 387 |
+
group_size = 256,
|
| 388 |
+
query_key_dim = 128,
|
| 389 |
+
expansion_factor = 4.,
|
| 390 |
+
attn_dropout = 0.1
|
| 391 |
+
):
|
| 392 |
+
super().__init__()
|
| 393 |
+
|
| 394 |
+
self.mossformerM = MossformerBlock(
|
| 395 |
+
dim=d_model,
|
| 396 |
+
depth=num_blocks,
|
| 397 |
+
group_size=group_size,
|
| 398 |
+
query_key_dim=query_key_dim,
|
| 399 |
+
expansion_factor=expansion_factor,
|
| 400 |
+
causal=causal,
|
| 401 |
+
attn_dropout=attn_dropout
|
| 402 |
+
)
|
| 403 |
+
self.norm = nn.LayerNorm(d_model, eps=1e-6)
|
| 404 |
+
|
| 405 |
+
def forward(
|
| 406 |
+
self,
|
| 407 |
+
src,
|
| 408 |
+
):
|
| 409 |
+
"""
|
| 410 |
+
Arguments
|
| 411 |
+
----------
|
| 412 |
+
src : torch.Tensor
|
| 413 |
+
Tensor shape [B, L, N],
|
| 414 |
+
where, B = Batchsize,
|
| 415 |
+
L = time points
|
| 416 |
+
N = number of filters
|
| 417 |
+
The sequence to the encoder layer (required).
|
| 418 |
+
src_mask : tensor
|
| 419 |
+
The mask for the src sequence (optional).
|
| 420 |
+
src_key_padding_mask : tensor
|
| 421 |
+
The mask for the src keys per batch (optional).
|
| 422 |
+
"""
|
| 423 |
+
output = self.mossformerM(src)
|
| 424 |
+
output = self.norm(output)
|
| 425 |
+
|
| 426 |
+
return output
|
| 427 |
+
|
| 428 |
+
class Computation_Block(nn.Module):
|
| 429 |
+
"""Computation block for dual-path processing.
|
| 430 |
+
|
| 431 |
+
Arguments
|
| 432 |
+
---------
|
| 433 |
+
num_blocks : int
|
| 434 |
+
Number of MossFormer layers used within each chunk.
|
| 437 |
+
out_channels : int
|
| 438 |
+
Dimensionality of inter/intra model.
|
| 439 |
+
norm : str
|
| 440 |
+
Normalization type.
|
| 441 |
+
skip_around_intra : bool
|
| 442 |
+
Skip connection around the intra layer.
|
| 443 |
+
linear_layer_after_inter_intra : bool
|
| 444 |
+
Linear layer or not after inter or intra.
|
| 445 |
+
|
| 446 |
+
Example
|
| 447 |
+
---------
|
| 448 |
+
>>> comp_block = Computation_Block(64)
|
| 449 |
+
>>> x = torch.randn(10, 64, 100)
|
| 450 |
+
>>> x = comp_block(x)
|
| 451 |
+
>>> x.shape
|
| 452 |
+
torch.Size([10, 64, 100])
|
| 453 |
+
"""
|
| 454 |
+
|
| 455 |
+
def __init__(
|
| 456 |
+
self,
|
| 457 |
+
num_blocks,
|
| 458 |
+
out_channels,
|
| 459 |
+
norm="ln",
|
| 460 |
+
skip_around_intra=True,
|
| 461 |
+
):
|
| 462 |
+
super(Computation_Block, self).__init__()
|
| 463 |
+
|
| 464 |
+
##MossFormer+: MossFormer with recurrence
|
| 465 |
+
self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
|
| 466 |
+
##MossFormerM2: the orignal MossFormer
|
| 467 |
+
#self.intra_mdl = MossFormerM2(num_blocks=num_blocks, d_model=out_channels)
|
| 468 |
+
self.skip_around_intra = skip_around_intra
|
| 469 |
+
|
| 470 |
+
# Norm
|
| 471 |
+
self.norm = norm
|
| 472 |
+
if norm is not None:
|
| 473 |
+
self.intra_norm = select_norm(norm, out_channels, 3)
|
| 474 |
+
|
| 475 |
+
def forward(self, x):
|
| 476 |
+
"""Returns the output tensor.
|
| 477 |
+
|
| 478 |
+
Arguments
|
| 479 |
+
---------
|
| 480 |
+
x : torch.Tensor
|
| 481 |
+
Input tensor of dimension [B, N, S].
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
Return
|
| 485 |
+
---------
|
| 486 |
+
out: torch.Tensor
|
| 487 |
+
Output tensor of dimension [B, N, S].
|
| 488 |
+
where, B = Batchsize,
|
| 489 |
+
N = number of filters
|
| 490 |
+
S = sequence time index
|
| 491 |
+
"""
|
| 492 |
+
B, N, S = x.shape
|
| 493 |
+
# intra RNN
|
| 494 |
+
# [B, S, N]
|
| 495 |
+
intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
|
| 496 |
+
|
| 497 |
+
intra = self.intra_mdl(intra)
|
| 498 |
+
|
| 499 |
+
# [B, N, S]
|
| 500 |
+
intra = intra.permute(0, 2, 1).contiguous()
|
| 501 |
+
if self.norm is not None:
|
| 502 |
+
intra = self.intra_norm(intra)
|
| 503 |
+
|
| 504 |
+
# [B, N, S]
|
| 505 |
+
if self.skip_around_intra:
|
| 506 |
+
intra = intra + x
|
| 507 |
+
|
| 508 |
+
out = intra
|
| 509 |
+
return out
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class MossFormer_MaskNet(nn.Module):
|
| 513 |
+
"""The dual path model which is the basis for dualpathrnn, sepformer, dptnet.
|
| 514 |
+
|
| 515 |
+
Arguments
|
| 516 |
+
---------
|
| 517 |
+
in_channels : int
|
| 518 |
+
Number of channels at the output of the encoder.
|
| 519 |
+
out_channels : int
|
| 520 |
+
Number of channels that would be inputted to the intra and inter blocks.
|
| 521 |
+
out_channels_final : int
|
| 522 |
+
Number of output channels from the final decoder convolution.
|
| 523 |
+
num_blocks : int
|
| 524 |
+
Number of MossFormer layers in the computation block.
|
| 525 |
+
norm : str
|
| 526 |
+
Normalization type.
|
| 527 |
+
num_spks : int
|
| 528 |
+
Number of sources (speakers).
|
| 529 |
+
skip_around_intra : bool
|
| 530 |
+
Skip connection around intra.
|
| 531 |
+
use_global_pos_enc : bool
|
| 532 |
+
Global positional encodings.
|
| 533 |
+
max_length : int
|
| 534 |
+
Maximum sequence length.
|
| 535 |
+
|
| 536 |
+
Example
|
| 537 |
+
---------
|
| 538 |
+
>>> mossformer_block = MossFormerM(1, 64, 8)
|
| 539 |
+
>>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
|
| 540 |
+
>>> x = torch.randn(10, 64, 2000)
|
| 541 |
+
>>> x = mossformer_masknet(x)
|
| 542 |
+
>>> x.shape
|
| 543 |
+
torch.Size([10, 64, 2000])
|
| 544 |
+
"""
|
| 545 |
+
|
| 546 |
+
def __init__(
|
| 547 |
+
self,
|
| 548 |
+
in_channels,
|
| 549 |
+
out_channels,
|
| 550 |
+
out_channels_final,
|
| 551 |
+
num_blocks=24,
|
| 552 |
+
norm="ln",
|
| 553 |
+
num_spks=1,
|
| 554 |
+
skip_around_intra=True,
|
| 555 |
+
use_global_pos_enc=True,
|
| 556 |
+
max_length=20000,
|
| 557 |
+
):
|
| 558 |
+
super(MossFormer_MaskNet, self).__init__()
|
| 559 |
+
self.num_spks = num_spks
|
| 560 |
+
self.num_blocks = num_blocks
|
| 561 |
+
self.norm = select_norm(norm, in_channels, 3)
|
| 562 |
+
self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
| 563 |
+
self.use_global_pos_enc = use_global_pos_enc
|
| 564 |
+
|
| 565 |
+
if self.use_global_pos_enc:
|
| 566 |
+
self.pos_enc = ScaledSinuEmbedding(out_channels)
|
| 567 |
+
|
| 568 |
+
self.mdl = Computation_Block(
|
| 569 |
+
num_blocks,
|
| 570 |
+
out_channels,
|
| 571 |
+
norm,
|
| 572 |
+
skip_around_intra=skip_around_intra,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
self.conv1d_out = nn.Conv1d(
|
| 576 |
+
out_channels, out_channels * num_spks, kernel_size=1
|
| 577 |
+
)
|
| 578 |
+
self.conv1_decoder = nn.Conv1d(out_channels, out_channels_final, 1, bias=False)
|
| 579 |
+
self.prelu = nn.PReLU()
|
| 580 |
+
self.activation = nn.ReLU()
|
| 581 |
+
# gated output layer
|
| 582 |
+
self.output = nn.Sequential(
|
| 583 |
+
nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
|
| 584 |
+
)
|
| 585 |
+
self.output_gate = nn.Sequential(
|
| 586 |
+
nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
def forward(self, x):
|
| 590 |
+
"""Returns the output tensor.
|
| 591 |
+
|
| 592 |
+
Arguments
|
| 593 |
+
---------
|
| 594 |
+
x : torch.Tensor
|
| 595 |
+
Input tensor of dimension [B, N, S].
|
| 596 |
+
|
| 597 |
+
Returns
|
| 598 |
+
-------
|
| 599 |
+
out : torch.Tensor
|
| 600 |
+
Output tensor of dimension [B, N, S] (estimate for the first speaker),
|
| 601 |
+
where,
|
| 602 |
+
B = Batchsize,
|
| 603 |
+
N = number of filters
|
| 604 |
+
S = the number of time frames
|
| 605 |
+
"""
|
| 606 |
+
|
| 607 |
+
# before each line we indicate the shape after executing the line
|
| 608 |
+
|
| 609 |
+
# [B, N, L]
|
| 610 |
+
x = self.norm(x)
|
| 611 |
+
|
| 612 |
+
# [B, N, L]
|
| 613 |
+
x = self.conv1d_encoder(x)
|
| 614 |
+
if self.use_global_pos_enc:
|
| 615 |
+
#x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
|
| 616 |
+
# x.size(1) ** 0.5)
|
| 617 |
+
base = x
|
| 618 |
+
x = x.transpose(1, -1)
|
| 619 |
+
emb = self.pos_enc(x)
|
| 620 |
+
emb = emb.transpose(0, -1)
|
| 621 |
+
#print('base: {}, emb: {}'.format(base.shape, emb.shape))
|
| 622 |
+
x = base + emb
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
# [B, N, S]
|
| 626 |
+
#for i in range(self.num_modules):
|
| 627 |
+
# x = self.dual_mdl[i](x)
|
| 628 |
+
x = self.mdl(x)
|
| 629 |
+
x = self.prelu(x)
|
| 630 |
+
|
| 631 |
+
# [B, N*spks, S]
|
| 632 |
+
x = self.conv1d_out(x)
|
| 633 |
+
B, _, S = x.shape
|
| 634 |
+
|
| 635 |
+
# [B*spks, N, S]
|
| 636 |
+
x = x.view(B * self.num_spks, -1, S)
|
| 637 |
+
|
| 638 |
+
# [B*spks, N, S]
|
| 639 |
+
x = self.output(x) * self.output_gate(x)
|
| 640 |
+
|
| 641 |
+
# [B*spks, N, S]
|
| 642 |
+
x = self.conv1_decoder(x)
|
| 643 |
+
|
| 644 |
+
# [B, spks, N, S]
|
| 645 |
+
_, N, L = x.shape
|
| 646 |
+
x = x.view(B, self.num_spks, N, L)
|
| 647 |
+
x = self.activation(x)
|
| 648 |
+
|
| 649 |
+
# [spks, B, N, S]
|
| 650 |
+
x = x.transpose(0, 1)
|
| 651 |
+
|
| 652 |
+
return x[0]
|
| 653 |
+
|
| 654 |
+
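# --- Illustrative usage (added note, not part of the original upload) ---
# A minimal sketch of MossFormer_MaskNet with the configuration used by the
# Mossformer wrapper in generator.py (80 in, 512 internal, 80 out); num_blocks is
# reduced here only to keep the example light.
#
#   masknet = MossFormer_MaskNet(in_channels=80, out_channels=512,
#                                out_channels_final=80, num_blocks=2)
#   mel = torch.randn(1, 80, 200)   # [B, N, S]
#   out = masknet(mel)              # [B, 80, S]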
class MossFormer(nn.Module):
|
| 655 |
+
def __init__(
|
| 656 |
+
self,
|
| 657 |
+
in_channels=512,
|
| 658 |
+
out_channels=512,
|
| 659 |
+
num_blocks=24,
|
| 660 |
+
kernel_size=16,
|
| 661 |
+
norm="ln",
|
| 662 |
+
num_spks=2,
|
| 663 |
+
skip_around_intra=True,
|
| 664 |
+
use_global_pos_enc=True,
|
| 665 |
+
max_length=20000,
|
| 666 |
+
):
|
| 667 |
+
super(MossFormer, self).__init__()
|
| 668 |
+
self.num_spks = num_spks
|
| 669 |
+
self.enc = Encoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=180)
|
| 670 |
+
self.mask_net = MossFormer_MaskNet(
|
| 671 |
+
in_channels=in_channels,
|
| 672 |
+
out_channels=out_channels,
|
| 673 |
+
num_blocks=num_blocks,
|
| 674 |
+
norm=norm,
|
| 675 |
+
num_spks=num_spks,
|
| 676 |
+
skip_around_intra=skip_around_intra,
|
| 677 |
+
use_global_pos_enc=use_global_pos_enc,
|
| 678 |
+
max_length=max_length,
|
| 679 |
+
)
|
| 680 |
+
self.dec = Decoder(
|
| 681 |
+
in_channels=out_channels,
|
| 682 |
+
out_channels=1,
|
| 683 |
+
kernel_size=kernel_size,
|
| 684 |
+
stride = kernel_size//2,
|
| 685 |
+
bias=False
|
| 686 |
+
)
|
| 687 |
+
def forward(self, input):
|
| 688 |
+
x = self.enc(input)
|
| 689 |
+
mask = self.mask_net(x)
|
| 690 |
+
x = torch.stack([x] * self.num_spks)
|
| 691 |
+
sep_x = x * mask
|
| 692 |
+
|
| 693 |
+
# Decoding
|
| 694 |
+
est_source = torch.cat(
|
| 695 |
+
[
|
| 696 |
+
self.dec(sep_x[i]).unsqueeze(-1)
|
| 697 |
+
for i in range(self.num_spks)
|
| 698 |
+
],
|
| 699 |
+
dim=-1,
|
| 700 |
+
)
|
| 701 |
+
T_origin = input.size(1)
|
| 702 |
+
T_est = est_source.size(1)
|
| 703 |
+
if T_origin > T_est:
|
| 704 |
+
est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
|
| 705 |
+
else:
|
| 706 |
+
est_source = est_source[:, :T_origin, :]
|
| 707 |
+
|
| 708 |
+
out = []
|
| 709 |
+
for spk in range(self.num_spks):
|
| 710 |
+
out.append(est_source[:,:,spk])
|
| 711 |
+
return out
|
models/mossformer2_sr/mossformer2_block.py
ADDED
|
@@ -0,0 +1,735 @@
|
| 1 |
+
"""
|
| 2 |
+
This source code is modified by Shengkui Zhao based on https://github.com/lucidrains/FLASH-pytorch
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn, einsum
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 11 |
+
from models.mossformer2_se.conv_module import ConvModule, GLU, FFConvM_Dilated
|
| 12 |
+
from models.mossformer2_se.fsmn import UniDeepFsmn, UniDeepFsmn_dilated
|
| 13 |
+
from torchinfo import summary
|
| 14 |
+
from models.mossformer2_se.layer_norm import CLayerNorm, GLayerNorm, GlobLayerNorm, ILayerNorm
|
| 15 |
+
|
| 16 |
+
# Helper functions
|
| 17 |
+
|
| 18 |
+
def identity(t, *args, **kwargs):
|
| 19 |
+
"""
|
| 20 |
+
Returns the input tensor unchanged.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
t (torch.Tensor): Input tensor.
|
| 24 |
+
*args: Additional arguments (ignored).
|
| 25 |
+
**kwargs: Additional keyword arguments (ignored).
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
torch.Tensor: The input tensor.
|
| 29 |
+
"""
|
| 30 |
+
return t
|
| 31 |
+
|
| 32 |
+
def append_dims(x, num_dims):
|
| 33 |
+
"""
|
| 34 |
+
Adds additional dimensions to the input tensor.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
x (torch.Tensor): Input tensor.
|
| 38 |
+
num_dims (int): Number of dimensions to append.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
torch.Tensor: Tensor with appended dimensions.
|
| 42 |
+
"""
|
| 43 |
+
if num_dims <= 0:
|
| 44 |
+
return x
|
| 45 |
+
return x.view(*x.shape, *((1,) * num_dims)) # Reshape to append dimensions
|
| 46 |
+
|
| 47 |
+
def exists(val):
|
| 48 |
+
"""
|
| 49 |
+
Checks if a value exists (is not None).
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
val: The value to check.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
bool: True if value exists, False otherwise.
|
| 56 |
+
"""
|
| 57 |
+
return val is not None
|
| 58 |
+
|
| 59 |
+
def default(val, d):
|
| 60 |
+
"""
|
| 61 |
+
Returns a default value if the given value does not exist.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
val: The value to check.
|
| 65 |
+
d: Default value to return if val does not exist.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
The original value if it exists, otherwise the default value.
|
| 69 |
+
"""
|
| 70 |
+
return val if exists(val) else d
|
| 71 |
+
|
| 72 |
+
def padding_to_multiple_of(n, mult):
|
| 73 |
+
"""
|
| 74 |
+
Calculates the amount of padding needed to make a number a multiple of another.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
n (int): The number to pad.
|
| 78 |
+
mult (int): The multiple to match.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
int: The padding amount required to make n a multiple of mult.
|
| 82 |
+
"""
|
| 83 |
+
remainder = n % mult
|
| 84 |
+
if remainder == 0:
|
| 85 |
+
return 0
|
| 86 |
+
return mult - remainder # Return the required padding
|
| 87 |
+
|
| 88 |
+
# Scale Normalization class
|
| 89 |
+
|
| 90 |
+
class ScaleNorm(nn.Module):
|
| 91 |
+
"""
|
| 92 |
+
ScaleNorm implements a scaled normalization technique for neural network layers.
|
| 93 |
+
|
| 94 |
+
Attributes:
|
| 95 |
+
dim (int): Dimension of the input features.
|
| 96 |
+
eps (float): Small value to prevent division by zero.
|
| 97 |
+
g (nn.Parameter): Learnable parameter for scaling.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, dim, eps=1e-5):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.scale = dim ** -0.5 # Calculate scale factor
|
| 103 |
+
self.eps = eps # Set epsilon
|
| 104 |
+
self.g = nn.Parameter(torch.ones(1)) # Initialize scaling parameter
|
| 105 |
+
|
| 106 |
+
def forward(self, x):
|
| 107 |
+
"""
|
| 108 |
+
Forward pass for the ScaleNorm layer.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
x (torch.Tensor): Input tensor.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
torch.Tensor: Scaled and normalized output tensor.
|
| 115 |
+
"""
|
| 116 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale # Compute norm
|
| 117 |
+
return x / norm.clamp(min=self.eps) * self.g # Normalize and scale
|
| 118 |
+
|
| 119 |
+
# Absolute positional encodings class
|
| 120 |
+
|
| 121 |
+
class ScaledSinuEmbedding(nn.Module):
|
| 122 |
+
"""
|
| 123 |
+
ScaledSinuEmbedding provides sinusoidal positional encodings for inputs.
|
| 124 |
+
|
| 125 |
+
Attributes:
|
| 126 |
+
scale (nn.Parameter): Learnable scale factor for the embeddings.
|
| 127 |
+
inv_freq (torch.Tensor): Inverse frequency used for sine and cosine calculations.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
def __init__(self, dim):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.scale = nn.Parameter(torch.ones(1,)) # Initialize scale
|
| 133 |
+
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) # Calculate inverse frequency
|
| 134 |
+
self.register_buffer('inv_freq', inv_freq) # Register as a buffer
|
| 135 |
+
|
| 136 |
+
def forward(self, x):
|
| 137 |
+
"""
|
| 138 |
+
Forward pass for the ScaledSinuEmbedding layer.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
torch.Tensor: Positional encoding tensor of shape (sequence_length, dim).
|
| 145 |
+
"""
|
| 146 |
+
n, device = x.shape[1], x.device # Extract sequence length and device
|
| 147 |
+
t = torch.arange(n, device=device).type_as(self.inv_freq) # Create time steps
|
| 148 |
+
sinu = einsum('i , j -> i j', t, self.inv_freq) # Calculate sine and cosine embeddings
|
| 149 |
+
emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1) # Concatenate sine and cosine embeddings
|
| 150 |
+
return emb * self.scale # Scale the embeddings
|
| 151 |
+
|
| 152 |
+
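# --- Illustrative usage (added note, not part of the original upload) ---
# A minimal sketch: ScaledSinuEmbedding returns [seq_len, dim] positional
# encodings for an input of shape [batch, seq_len, dim].
#
#   pos = ScaledSinuEmbedding(dim=512)
#   x = torch.randn(2, 100, 512)
#   emb = pos(x)                    # -> [100, 512]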
class OffsetScale(nn.Module):
|
| 153 |
+
"""
|
| 154 |
+
OffsetScale applies learned offsets and scales to the input tensor.
|
| 155 |
+
|
| 156 |
+
Attributes:
|
| 157 |
+
gamma (nn.Parameter): Learnable scale parameter for each head.
|
| 158 |
+
beta (nn.Parameter): Learnable offset parameter for each head.
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(self, dim, heads=1):
|
| 162 |
+
super().__init__()
|
| 163 |
+
self.gamma = nn.Parameter(torch.ones(heads, dim)) # Initialize scale parameters
|
| 164 |
+
self.beta = nn.Parameter(torch.zeros(heads, dim)) # Initialize offset parameters
|
| 165 |
+
nn.init.normal_(self.gamma, std=0.02) # Normal initialization for gamma
|
| 166 |
+
|
| 167 |
+
def forward(self, x):
|
| 168 |
+
"""
|
| 169 |
+
Forward pass for the OffsetScale layer.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
x (torch.Tensor): Input tensor.
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
List[torch.Tensor]: A list of tensors with applied offsets and scales for each head.
|
| 176 |
+
"""
|
| 177 |
+
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta # Apply scaling and offsets
|
| 178 |
+
return out.unbind(dim=-2) # Unbind heads into a list
|
| 179 |
+
|
| 180 |
+
# Feed-Forward Convolutional Module
|
| 181 |
+
|
| 182 |
+
class FFConvM(nn.Module):
|
| 183 |
+
"""
|
| 184 |
+
FFConvM is a feed-forward convolutional module with normalization and dropout.
|
| 185 |
+
|
| 186 |
+
Attributes:
|
| 187 |
+
dim_in (int): Input dimension of the features.
|
| 188 |
+
dim_out (int): Output dimension after processing.
|
| 189 |
+
norm_klass (nn.Module): Normalization class to be used.
|
| 190 |
+
dropout (float): Dropout probability.
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
def __init__(
|
| 194 |
+
self,
|
| 195 |
+
dim_in,
|
| 196 |
+
dim_out,
|
| 197 |
+
norm_klass=nn.LayerNorm,
|
| 198 |
+
dropout=0.1
|
| 199 |
+
):
|
| 200 |
+
super().__init__()
|
| 201 |
+
self.mdl = nn.Sequential(
|
| 202 |
+
norm_klass(dim_in), # Normalize input
|
| 203 |
+
nn.Linear(dim_in, dim_out), # Linear transformation
|
| 204 |
+
nn.SiLU(), # Activation function
|
| 205 |
+
ConvModule(dim_out), # Convolution module
|
| 206 |
+
nn.Dropout(dropout) # Apply dropout
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
def forward(self, x):
|
| 210 |
+
"""
|
| 211 |
+
Forward pass for the FFConvM module.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
x (torch.Tensor): Input tensor.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
torch.Tensor: Output tensor after processing.
|
| 218 |
+
"""
|
| 219 |
+
output = self.mdl(x) # Pass through the model
|
| 220 |
+
return output
|
| 221 |
+
|
| 222 |
+
class FFM(nn.Module):
|
| 223 |
+
"""
|
| 224 |
+
FFM is a feed-forward module with normalization and dropout.
|
| 225 |
+
|
| 226 |
+
Attributes:
|
| 227 |
+
dim_in (int): Input dimension of the features.
|
| 228 |
+
dim_out (int): Output dimension after processing.
|
| 229 |
+
norm_klass (nn.Module): Normalization class to be used.
|
| 230 |
+
dropout (float): Dropout probability.
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
def __init__(
|
| 234 |
+
self,
|
| 235 |
+
dim_in,
|
| 236 |
+
dim_out,
|
| 237 |
+
norm_klass=nn.LayerNorm,
|
| 238 |
+
dropout=0.1
|
| 239 |
+
):
|
| 240 |
+
super().__init__()
|
| 241 |
+
self.mdl = nn.Sequential(
|
| 242 |
+
norm_klass(dim_in), # Normalize input
|
| 243 |
+
nn.Linear(dim_in, dim_out), # Linear transformation
|
| 244 |
+
nn.SiLU(), # Activation function
|
| 245 |
+
nn.Dropout(dropout) # Apply dropout
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def forward(self, x):
|
| 249 |
+
"""
|
| 250 |
+
Forward pass for the FFM module.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
x (torch.Tensor): Input tensor.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
torch.Tensor: Output tensor after processing.
|
| 257 |
+
"""
|
| 258 |
+
output = self.mdl(x) # Pass through the model
|
| 259 |
+
return output
|
| 260 |
+
|
| 261 |
+
class FLASH_ShareA_FFConvM(nn.Module):
|
| 262 |
+
"""
|
| 263 |
+
Fast Shared Dual Attention Mechanism with feed-forward convolutional blocks.
|
| 264 |
+
Published in paper: "MossFormer: Pushing the Performance Limit of Monaural Speech Separation
|
| 265 |
+
using Gated Single-Head Transformer with Convolution-Augmented Joint Self-Attentions", ICASSP 2023.
|
| 266 |
+
(https://arxiv.org/abs/2302.11824)
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
dim (int): Input dimension.
|
| 270 |
+
group_size (int, optional): Size of groups for processing. Defaults to 256.
|
| 271 |
+
query_key_dim (int, optional): Dimension of the query and key. Defaults to 128.
|
| 272 |
+
expansion_factor (float, optional): Factor to expand the hidden dimension. Defaults to 1.
|
| 273 |
+
causal (bool, optional): Whether to use causal masking. Defaults to False.
|
| 274 |
+
dropout (float, optional): Dropout rate. Defaults to 0.1.
|
| 275 |
+
rotary_pos_emb (optional): Rotary positional embeddings for attention. Defaults to None.
|
| 276 |
+
norm_klass (callable, optional): Normalization class to use. Defaults to nn.LayerNorm.
|
| 277 |
+
shift_tokens (bool, optional): Whether to shift tokens for attention calculation. Defaults to True.
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
def __init__(
|
| 281 |
+
self,
|
| 282 |
+
*,
|
| 283 |
+
dim,
|
| 284 |
+
group_size=256,
|
| 285 |
+
query_key_dim=128,
|
| 286 |
+
expansion_factor=1.,
|
| 287 |
+
causal=False,
|
| 288 |
+
dropout=0.1,
|
| 289 |
+
rotary_pos_emb=None,
|
| 290 |
+
norm_klass=nn.LayerNorm,
|
| 291 |
+
shift_tokens=True
|
| 292 |
+
):
|
| 293 |
+
super().__init__()
|
| 294 |
+
hidden_dim = int(dim * expansion_factor)
|
| 295 |
+
self.group_size = group_size
|
| 296 |
+
self.causal = causal
|
| 297 |
+
self.shift_tokens = shift_tokens
|
| 298 |
+
|
| 299 |
+
# Initialize positional embeddings, dropout, and projections
|
| 300 |
+
self.rotary_pos_emb = rotary_pos_emb
|
| 301 |
+
self.dropout = nn.Dropout(dropout)
|
| 302 |
+
|
| 303 |
+
# Feed-forward layers
|
| 304 |
+
self.to_hidden = FFConvM(
|
| 305 |
+
dim_in=dim,
|
| 306 |
+
dim_out=hidden_dim,
|
| 307 |
+
norm_klass=norm_klass,
|
| 308 |
+
dropout=dropout,
|
| 309 |
+
)
|
| 310 |
+
self.to_qk = FFConvM(
|
| 311 |
+
dim_in=dim,
|
| 312 |
+
dim_out=query_key_dim,
|
| 313 |
+
norm_klass=norm_klass,
|
| 314 |
+
dropout=dropout,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# Offset and scale for query and key
|
| 318 |
+
self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
|
| 319 |
+
|
| 320 |
+
self.to_out = FFConvM(
|
| 321 |
+
dim_in=dim * 2,
|
| 322 |
+
dim_out=dim,
|
| 323 |
+
norm_klass=norm_klass,
|
| 324 |
+
dropout=dropout,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
self.gateActivate = nn.Sigmoid()
|
| 328 |
+
|
| 329 |
+
def forward(self, x, *, mask=None):
|
| 330 |
+
"""
|
| 331 |
+
Forward pass for FLASH layer.
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
x (Tensor): Input tensor of shape (batch, seq_len, features).
|
| 335 |
+
mask (Tensor, optional): Mask for attention. Defaults to None.
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
Tensor: Output tensor after applying attention and projections.
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
# Pre-normalization step
|
| 342 |
+
normed_x = x
|
| 343 |
+
residual = x # Save residual for skip connection
|
| 344 |
+
|
| 345 |
+
# Token shifting if enabled
|
| 346 |
+
if self.shift_tokens:
|
| 347 |
+
x_shift, x_pass = normed_x.chunk(2, dim=-1)
|
| 348 |
+
x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
|
| 349 |
+
normed_x = torch.cat((x_shift, x_pass), dim=-1)
|
| 350 |
+
|
| 351 |
+
# Initial projections
|
| 352 |
+
v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
|
| 353 |
+
qk = self.to_qk(normed_x)
|
| 354 |
+
|
| 355 |
+
# Offset and scale
|
| 356 |
+
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
|
| 357 |
+
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
|
| 358 |
+
|
| 359 |
+
# Output calculation with gating
|
| 360 |
+
out = (att_u * v) * self.gateActivate(att_v * u)
|
| 361 |
+
x = x + self.to_out(out) # Residual connection
|
| 362 |
+
return x
|
| 363 |
+
|
| 364 |
+
def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
|
| 365 |
+
"""
|
| 366 |
+
Calculate attention output using quadratic and linear attention mechanisms.
|
| 367 |
+
|
| 368 |
+
Args:
|
| 369 |
+
x (Tensor): Input tensor of shape (batch, seq_len, features).
|
| 370 |
+
quad_q (Tensor): Quadratic query representation.
|
| 371 |
+
lin_q (Tensor): Linear query representation.
|
| 372 |
+
quad_k (Tensor): Quadratic key representation.
|
| 373 |
+
lin_k (Tensor): Linear key representation.
|
| 374 |
+
v (Tensor): Value representation.
|
| 375 |
+
u (Tensor): Additional value representation.
|
| 376 |
+
mask (Tensor, optional): Mask for attention. Defaults to None.
|
| 377 |
+
|
| 378 |
+
Returns:
|
| 379 |
+
Tuple[Tensor, Tensor]: Attention outputs for v and u.
|
| 380 |
+
"""
|
| 381 |
+
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
|
| 382 |
+
|
| 383 |
+
# Apply mask to linear keys if provided
|
| 384 |
+
if exists(mask):
|
| 385 |
+
lin_mask = rearrange(mask, '... -> ... 1')
|
| 386 |
+
lin_k = lin_k.masked_fill(~lin_mask, 0.)
|
| 387 |
+
|
| 388 |
+
# Rotate queries and keys with rotary positional embeddings
|
| 389 |
+
if exists(self.rotary_pos_emb):
|
| 390 |
+
quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
|
| 391 |
+
|
| 392 |
+
# Padding for group processing
|
| 393 |
+
padding = padding_to_multiple_of(n, g)
|
| 394 |
+
if padding > 0:
|
| 395 |
+
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value=0.), (quad_q, quad_k, lin_q, lin_k, v, u))
|
| 396 |
+
mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
|
| 397 |
+
mask = F.pad(mask, (0, padding), value=False)
|
| 398 |
+
|
| 399 |
+
# Group along sequence for attention
|
| 400 |
+
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
|
| 401 |
+
|
| 402 |
+
if exists(mask):
|
| 403 |
+
mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g)
|
| 404 |
+
|
| 405 |
+
# Calculate quadratic attention output
|
| 406 |
+
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
|
| 407 |
+
attn = F.relu(sim) ** 2 # ReLU activation
|
| 408 |
+
attn = self.dropout(attn)
|
| 409 |
+
|
| 410 |
+
# Apply mask to attention if provided
|
| 411 |
+
if exists(mask):
|
| 412 |
+
attn = attn.masked_fill(~mask, 0.)
|
| 413 |
+
|
| 414 |
+
# Apply causal mask if needed
|
| 415 |
+
if self.causal:
|
| 416 |
+
causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
|
| 417 |
+
attn = attn.masked_fill(causal_mask, 0.)
|
| 418 |
+
|
| 419 |
+
# Calculate output from attention
|
| 420 |
+
quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
|
| 421 |
+
quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
|
| 422 |
+
|
| 423 |
+
# Calculate linear attention output
|
| 424 |
+
if self.causal:
|
| 425 |
+
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
|
| 426 |
+
lin_kv = lin_kv.cumsum(dim=1) # Cumulative sum for linear attention
|
| 427 |
+
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
|
| 428 |
+
lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
|
| 429 |
+
|
| 430 |
+
lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
|
| 431 |
+
lin_ku = lin_ku.cumsum(dim=1) # Cumulative sum for linear attention
|
| 432 |
+
lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
|
| 433 |
+
lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
|
| 434 |
+
else:
|
| 435 |
+
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
|
| 436 |
+
lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
|
| 437 |
+
|
| 438 |
+
lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
|
| 439 |
+
lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
|
| 440 |
+
|
| 441 |
+
# Reshape and remove padding from outputs
|
| 442 |
+
return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v + lin_out_v, quad_out_u + lin_out_u))
|
| 443 |
+
|
| 444 |
+
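# --- Illustrative usage (added note, not part of the original upload) ---
# A minimal sketch of the gated single-head attention unit above. The value
# expansion_factor=4. is an assumption taken from the MossFormerM configuration
# elsewhere in this repo; with it, the gated hidden split equals 2*dim as
# expected by the to_out projection.
#
#   attn = FLASH_ShareA_FFConvM(dim=512, group_size=256, query_key_dim=128,
#                               expansion_factor=4., causal=False)
#   x = torch.randn(2, 300, 512)
#   y = attn(x)                     # same shape as x (residual connection inside)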
class Gated_FSMN(nn.Module):
|
| 445 |
+
"""
|
| 446 |
+
Gated Feedforward Sequential Memory Network (FSMN) class.
|
| 447 |
+
|
| 448 |
+
This class implements a gated FSMN that combines two feed-forward
|
| 449 |
+
convolutional branches with a feedforward sequential memory (FSMN) module.
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
in_channels (int): Number of input channels.
|
| 453 |
+
out_channels (int): Number of output channels.
|
| 454 |
+
lorder (int): Order of the filter for FSMN.
|
| 455 |
+
hidden_size (int): Number of hidden units in the network.
|
| 456 |
+
"""
|
| 457 |
+
def __init__(self, in_channels, out_channels, lorder, hidden_size):
|
| 458 |
+
super().__init__()
|
| 459 |
+
# Feedforward network for the first branch (u)
|
| 460 |
+
self.to_u = FFConvM(
|
| 461 |
+
dim_in=in_channels,
|
| 462 |
+
dim_out=hidden_size,
|
| 463 |
+
norm_klass=nn.LayerNorm,
|
| 464 |
+
dropout=0.1,
|
| 465 |
+
)
|
| 466 |
+
# Feedforward network for the second branch (v)
|
| 467 |
+
self.to_v = FFConvM(
|
| 468 |
+
dim_in=in_channels,
|
| 469 |
+
dim_out=hidden_size,
|
| 470 |
+
norm_klass=nn.LayerNorm,
|
| 471 |
+
dropout=0.1,
|
| 472 |
+
)
|
| 473 |
+
# Frequency selective memory network
|
| 474 |
+
self.fsmn = UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
|
| 475 |
+
|
| 476 |
+
def forward(self, x):
|
| 477 |
+
"""
|
| 478 |
+
Forward pass for the Gated FSMN.
|
| 479 |
+
|
| 480 |
+
Args:
|
| 481 |
+
x (Tensor): Input tensor of shape (batch_size, in_channels, sequence_length).
|
| 482 |
+
|
| 483 |
+
Returns:
|
| 484 |
+
Tensor: Output tensor after applying gated FSMN operations.
|
| 485 |
+
"""
|
| 486 |
+
input = x
|
| 487 |
+
x_u = self.to_u(x) # Process input through the first branch
|
| 488 |
+
x_v = self.to_v(x) # Process input through the second branch
|
| 489 |
+
x_u = self.fsmn(x_u) # Apply FSMN to the output of the first branch
|
| 490 |
+
x = x_v * x_u + input # Combine outputs with the original input
|
| 491 |
+
return x
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
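A hypothetical usage sketch for the gated FSMN; it assumes the `FFConvM` and `UniDeepFsmn` modules defined elsewhere in this repository, and the channel sizes mirror how `Gated_FSMN_Block` below instantiates it:

import torch

fsmn = Gated_FSMN(in_channels=256, out_channels=256, lorder=20, hidden_size=256)
x = torch.randn(2, 100, 256)   # (batch, time, channels)
y = fsmn(x)                    # gated memory output plus residual
print(y.shape)                 # expected: torch.Size([2, 100, 256])
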
class Gated_FSMN_Block(nn.Module):
    """
    A 1-D convolutional block that incorporates a gated FSMN.

    The block applies a pointwise convolution with PReLU, a gated FSMN
    sandwiched between channel-wise normalization layers, and a final
    pointwise convolution, with a residual connection around the whole block.

    Args:
        dim (int): Dimensionality of the input.
        inner_channels (int): Number of channels in the inner layers.
        group_size (int): Size of the groups for normalization.
        norm_type (str): Type of normalization to use ('scalenorm' or 'layernorm').
    """
    def __init__(self, dim, inner_channels=256, group_size=256, norm_type='scalenorm'):
        super(Gated_FSMN_Block, self).__init__()
        # Choose normalization class based on the provided type
        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.group_size = group_size

        # First convolutional layer with PReLU activation
        self.conv1 = nn.Sequential(
            nn.Conv1d(dim, inner_channels, kernel_size=1),
            nn.PReLU(),
        )
        self.norm1 = CLayerNorm(inner_channels)  # Normalization after first convolution
        self.gated_fsmn = Gated_FSMN(inner_channels, inner_channels, lorder=20, hidden_size=inner_channels)  # Gated FSMN layer
        self.norm2 = CLayerNorm(inner_channels)  # Normalization after FSMN
        self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1)  # Final convolutional layer

    def forward(self, input):
        """
        Forward pass for the Gated FSMN Block.

        Args:
            input (Tensor): Input tensor of shape (batch_size, sequence_length, dim).

        Returns:
            Tensor: Output tensor of the same shape after processing through the block.
        """
        conv1 = self.conv1(input.transpose(2, 1))        # Apply first convolution on (B, dim, S)
        norm1 = self.norm1(conv1)                        # Apply normalization
        seq_out = self.gated_fsmn(norm1.transpose(2, 1)) # Apply gated FSMN on (B, S, C)
        norm2 = self.norm2(seq_out.transpose(2, 1))      # Apply second normalization
        conv2 = self.conv2(norm2)                        # Apply final convolution
        return conv2.transpose(2, 1) + input             # Residual connection

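A hypothetical usage sketch for the block; it assumes the `ScaleNorm` and `CLayerNorm` classes defined earlier in this module, and the shapes are illustrative:

import torch

block = Gated_FSMN_Block(dim=512, inner_channels=256)
x = torch.randn(2, 100, 512)   # (batch, time, dim)
y = block(x)                   # residual output
print(y.shape)                 # expected: torch.Size([2, 100, 512])
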
class MossformerBlock_GFSMN(nn.Module):
    """
    Mossformer Block with Gated FSMN.

    This block interleaves attention layers with gated FSMN blocks
    to process input sequences.

    Args:
        dim (int): Dimensionality of the input.
        depth (int): Number of layers in the block.
        group_size (int): Size of the chunks used by the attention layers.
        query_key_dim (int): Dimension of the query and key in attention.
        expansion_factor (float): Expansion factor for feedforward layers.
        causal (bool): If True, enables causal attention.
        attn_dropout (float): Dropout rate for attention layers.
        norm_type (str): Type of normalization to use ('scalenorm' or 'layernorm').
        shift_tokens (bool): If True, shifts tokens in the attention layer.
    """
    def __init__(self, *, dim, depth, group_size=256, query_key_dim=128, expansion_factor=4., causal=False, attn_dropout=0.1, norm_type='scalenorm', shift_tokens=True):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.group_size = group_size

        # Rotary positional embedding for attention
        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))

        # Create a list of Gated FSMN blocks
        self.fsmn = nn.ModuleList([Gated_FSMN_Block(dim) for _ in range(depth)])

        # Create a list of attention layers using FLASH_ShareA_FFConvM
        self.layers = nn.ModuleList([
            FLASH_ShareA_FFConvM(
                dim=dim,
                group_size=group_size,
                query_key_dim=query_key_dim,
                expansion_factor=expansion_factor,
                causal=causal,
                dropout=attn_dropout,
                rotary_pos_emb=rotary_pos_emb,
                norm_klass=norm_klass,
                shift_tokens=shift_tokens
            ) for _ in range(depth)
        ])

    def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
        """
        Builds repeated UniDeep FSMN layers.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            lorder (int): Order of the filter for FSMN.
            hidden_size (int): Number of hidden units.
            repeats (int): Number of repetitions.

        Returns:
            Sequential: A sequential container with repeated layers.
        """
        blocks = [
            UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
            for _ in range(repeats)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x, *, mask=None):
        """
        Forward pass for the Mossformer Block with Gated FSMN.

        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length, dim).
            mask (Tensor, optional): Mask tensor for attention operations.

        Returns:
            Tensor: Output tensor after processing through the block.
        """
        # Alternate each attention (FLASH) layer with its paired Gated FSMN block
        for flash, fsmn in zip(self.layers, self.fsmn):
            x = flash(x, mask=mask)
            x = fsmn(x)

        return x

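A hypothetical usage sketch (the constructor is keyword-only; the depth, dimensions, and sequence length below are illustrative):

import torch

block = MossformerBlock_GFSMN(dim=512, depth=4)
x = torch.randn(2, 1000, 512)   # (batch, time, dim)
y = block(x)
print(y.shape)                  # expected: torch.Size([2, 1000, 512])
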
class MossformerBlock(nn.Module):
    """
    Mossformer Block with attention mechanisms.

    This block processes input sequences with a stack of attention
    layers that incorporate rotary positional embeddings. It allows
    for configurable normalization types and can handle causal
    attention.

    Args:
        dim (int): Dimensionality of the input.
        depth (int): Number of attention layers in the block.
        group_size (int, optional): Size of the chunks used by the attention layers. Default is 256.
        query_key_dim (int, optional): Dimension of the query and key in attention. Default is 128.
        expansion_factor (float, optional): Expansion factor for feedforward layers. Default is 4.
        causal (bool, optional): If True, enables causal attention. Default is False.
        attn_dropout (float, optional): Dropout rate for attention layers. Default is 0.1.
        norm_type (str, optional): Type of normalization to use ('scalenorm' or 'layernorm'). Default is 'scalenorm'.
        shift_tokens (bool, optional): If True, shifts tokens in the attention layer. Default is True.
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        causal=False,
        attn_dropout=0.1,
        norm_type='scalenorm',
        shift_tokens=True
    ):
        super().__init__()

        # Ensure normalization type is valid
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        # Select normalization class based on the provided type
        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.group_size = group_size  # Chunk size used by the attention layers

        # Rotary positional embedding for attention
        # (rotary dimension capped at 32: partial rotary embeddings, from Wang et al., GPT-J)
        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))

        # Create a list of attention layers using FLASH_ShareA_FFConvM
        self.layers = nn.ModuleList([
            FLASH_ShareA_FFConvM(
                dim=dim,
                group_size=group_size,
                query_key_dim=query_key_dim,
                expansion_factor=expansion_factor,
                causal=causal,
                dropout=attn_dropout,
                rotary_pos_emb=rotary_pos_emb,
                norm_klass=norm_klass,
                shift_tokens=shift_tokens
            ) for _ in range(depth)
        ])

    def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
        """
        Builds repeated UniDeep FSMN layers.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            lorder (int): Order of the filter for FSMN.
            hidden_size (int): Number of hidden units.
            repeats (int, optional): Number of repetitions. Default is 1.

        Returns:
            Sequential: A sequential container with repeated layers.
        """
        blocks = [
            UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
            for _ in range(repeats)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x, *, mask=None):
        """
        Forward pass for the Mossformer Block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length, dim).
            mask (Tensor, optional): Mask tensor for attention operations.

        Returns:
            Tensor: Output tensor after processing through the block.
        """
        # Process input through each attention layer
        for flash in self.layers:
            x = flash(x, mask=mask)  # Apply attention layer with optional mask

        return x  # Return the final output tensor

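A hypothetical usage sketch for the attention-only variant (same calling convention as the gated-FSMN block above; shapes are illustrative):

import torch

block = MossformerBlock(dim=512, depth=6)
x = torch.randn(2, 1000, 512)   # (batch, time, dim)
y = block(x)
print(y.shape)                  # expected: torch.Size([2, 1000, 512])
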
models/mossformer2_sr/mossformer2_sr_wrapper.py
ADDED
@@ -0,0 +1,52 @@
from models.mossformer2_sr.generator import Mossformer, Generator
import torch.nn as nn

class MossFormer2_SR_48K(nn.Module):
    """
    The MossFormer2_SR_48K model for speech super-resolution.

    This class encapsulates the MossFormer2 module and the HiFi-GAN
    Generator within a higher-level model. It processes input audio
    features to produce bandwidth-expanded, higher-resolution outputs.

    Arguments
    ---------
    args : Namespace
        Configuration arguments that may include hyperparameters
        and model settings (not utilized in this implementation but
        can be extended for flexibility).

    Example
    ---------
    >>> model = MossFormer2_SR_48K(args)
    >>> x = torch.randn(10, 180, 2000)  # Example input
    >>> outputs = model(x)              # Forward pass
    >>> outputs.shape                   # Check output shape
    """

    def __init__(self, args):
        super(MossFormer2_SR_48K, self).__init__()
        # MossFormer2 module that processes the input features
        self.model_m = Mossformer()
        # HiFi-GAN generator that synthesizes the bandwidth-expanded output
        self.model_g = Generator(args)

    def forward(self, x):
        """
        Forward pass through the model.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S], where B is the batch size,
            N is the number of mel bins, and S is the sequence length
            (e.g., time frames).

        Returns
        -------
        outputs : torch.Tensor
            Bandwidth-expanded audio output tensor from the model.

        """
        x = self.model_m(x)        # Process features with the MossFormer2 module
        outputs = self.model_g(x)  # Generate the bandwidth-expanded output
        return outputs             # Return the outputs
models/mossformer2_sr/snake.py
ADDED
@@ -0,0 +1,33 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    # 1-D convolution wrapped with weight normalization
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    # 1-D transposed convolution wrapped with weight normalization
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings a roughly 1.4x model speed-up
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1 / alpha) * sin^2(alpha * x), applied channel-wise
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    """Snake activation with a learnable per-channel frequency parameter alpha."""
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
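The snake activation computes x + (1/alpha) * sin^2(alpha * x) with a learnable per-channel alpha. A quick sketch of applying Snake1d to a (batch, channels, time) tensor (sizes are illustrative):

import torch

act = Snake1d(channels=64)
x = torch.randn(2, 64, 16000)
y = act(x)                 # elementwise x + (1/alpha) * sin(alpha * x)**2
print(y.shape)             # expected: torch.Size([2, 64, 16000])
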
models/mossformer2_sr/utils.py
ADDED
@@ -0,0 +1,37 @@
import glob
import os
import torch
from torch.nn.utils import weight_norm

def init_weights(m, mean=0.0, std=0.01):
    # Initialize convolutional weights from a normal distribution
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def apply_weight_norm(m):
    # Apply weight normalization to every convolutional module
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)

def get_padding(kernel_size, dilation=1):
    # "Same" padding for a dilated 1-D convolution with stride 1
    return int((kernel_size * dilation - dilation) / 2)

def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")

def scan_checkpoint(cp_dir, prefix):
    # Return the most recent checkpoint matching '<prefix>????????' in cp_dir, or None
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
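A brief illustration of the padding and checkpoint helpers above (the checkpoint directory and prefix are hypothetical):

print(get_padding(kernel_size=7, dilation=1))   # 3 -> keeps the length with stride 1
print(get_padding(kernel_size=3, dilation=2))   # 2
latest = scan_checkpoint("checkpoints", "g_")   # e.g. 'checkpoints/g_00100000', or None if nothing matches
print(latest)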