JYP2024 committed
Commit 70cbc33 · verified · 1 Parent(s): bbf6298

Initial upload (hf_transfer enabled)

SSL_WavLM/WavLM.py ADDED
@@ -0,0 +1,762 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from .modules import (
21
+ Fp32GroupNorm,
22
+ Fp32LayerNorm,
23
+ GradMultiply,
24
+ MultiheadAttention,
25
+ SamePad,
26
+ init_bert_params,
27
+ get_activation_fn,
28
+ TransposeLast,
29
+ GLU_Linear,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def compute_mask_indices(
36
+ shape: Tuple[int, int],
37
+ padding_mask: Optional[torch.Tensor],
38
+ mask_prob: float,
39
+ mask_length: int,
40
+ mask_type: str = "static",
41
+ mask_other: float = 0.0,
42
+ min_masks: int = 0,
43
+ no_overlap: bool = False,
44
+ min_space: int = 0,
45
+ ) -> np.ndarray:
46
+ """
47
+ Computes random mask spans for a given shape
48
+
49
+ Args:
50
+ shape: the shape for which to compute masks.
51
+ should be of size 2 where first element is batch size and 2nd is timesteps
52
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
+ mask_type: how to compute mask lengths
57
+ static = fixed size
58
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
59
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
+ poisson = sample from poisson distribution with lambda = mask length
61
+ min_masks: minimum number of masked spans
62
+ no_overlap: if true, will use an alternative recursive algorithm that prevents spans from overlapping
63
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
+ """
65
+
66
+ bsz, all_sz = shape
67
+ mask = np.full((bsz, all_sz), False)
68
+
69
+ all_num_mask = int(
70
+ # add a random number for probabilistic rounding
71
+ mask_prob * all_sz / float(mask_length)
72
+ + np.random.rand()
73
+ )
74
+
75
+ all_num_mask = max(min_masks, all_num_mask)
76
+
77
+ mask_idcs = []
78
+ for i in range(bsz):
79
+ if padding_mask is not None:
80
+ sz = all_sz - padding_mask[i].long().sum().item()
81
+ num_mask = int(
82
+ # add a random number for probabilistic rounding
83
+ mask_prob * sz / float(mask_length)
84
+ + np.random.rand()
85
+ )
86
+ num_mask = max(min_masks, num_mask)
87
+ else:
88
+ sz = all_sz
89
+ num_mask = all_num_mask
90
+
91
+ if mask_type == "static":
92
+ lengths = np.full(num_mask, mask_length)
93
+ elif mask_type == "uniform":
94
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
+ elif mask_type == "normal":
96
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
+ lengths = [max(1, int(round(x))) for x in lengths]
98
+ elif mask_type == "poisson":
99
+ lengths = np.random.poisson(mask_length, size=num_mask)
100
+ lengths = [int(round(x)) for x in lengths]
101
+ else:
102
+ raise Exception("unknown mask selection " + mask_type)
103
+
104
+ if sum(lengths) == 0:
105
+ lengths[0] = min(mask_length, sz - 1)
106
+
107
+ if no_overlap:
108
+ mask_idc = []
109
+
110
+ def arrange(s, e, length, keep_length):
111
+ span_start = np.random.randint(s, e - length)
112
+ mask_idc.extend(span_start + i for i in range(length))
113
+
114
+ new_parts = []
115
+ if span_start - s - min_space >= keep_length:
116
+ new_parts.append((s, span_start - min_space + 1))
117
+ if e - span_start - keep_length - min_space > keep_length:
118
+ new_parts.append((span_start + length + min_space, e))
119
+ return new_parts
120
+
121
+ parts = [(0, sz)]
122
+ min_length = min(lengths)
123
+ for length in sorted(lengths, reverse=True):
124
+ lens = np.fromiter(
125
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
+ int,  # np.int was removed in NumPy 1.24; the builtin int is equivalent here
127
+ )
128
+ l_sum = np.sum(lens)
129
+ if l_sum == 0:
130
+ break
131
+ probs = lens / np.sum(lens)
132
+ c = np.random.choice(len(parts), p=probs)
133
+ s, e = parts.pop(c)
134
+ parts.extend(arrange(s, e, length, min_length))
135
+ mask_idc = np.asarray(mask_idc)
136
+ else:
137
+ min_len = min(lengths)
138
+ if sz - min_len <= num_mask:
139
+ min_len = sz - num_mask - 1
140
+
141
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
+
143
+ mask_idc = np.asarray(
144
+ [
145
+ mask_idc[j] + offset
146
+ for j in range(len(mask_idc))
147
+ for offset in range(lengths[j])
148
+ ]
149
+ )
150
+
151
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
+
153
+ min_len = min([len(m) for m in mask_idcs])
154
+ for i, mask_idc in enumerate(mask_idcs):
155
+ if len(mask_idc) > min_len:
156
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
+ mask[i, mask_idc] = True
158
+
159
+ return mask
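
Editorial note (not part of the committed file): a minimal usage sketch of the function above, using the masking defaults that WavLMConfig sets below (mask_prob=0.65, mask_length=10).

    mask = compute_mask_indices(
        shape=(2, 100),        # batch of 2 utterances, 100 frames each
        padding_mask=None,
        mask_prob=0.65,
        mask_length=10,
        mask_type="static",
        min_masks=2,
    )
    # mask is a boolean np.ndarray of shape (2, 100); True marks the frames that
    # WavLM.apply_mask later replaces with its learned mask embedding.
    print(mask.sum(axis=-1))   # number of masked frames per utterance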
160
+
161
+
162
+ class WavLMConfig:
163
+ def __init__(self, cfg=None):
164
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
165
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
166
+
167
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
168
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
169
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
170
+ self.activation_fn: str = "gelu" # activation function to use
171
+
172
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
173
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
174
+ self.conv_bias: bool = False # include bias in conv encoder
175
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
176
+
177
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
178
+
179
+ # dropouts
180
+ self.dropout: float = 0.1 # dropout probability for the transformer
181
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
182
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
183
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
184
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
185
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
186
+
187
+ # masking
188
+ self.mask_length: int = 10 # mask length
189
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
190
+ self.mask_selection: str = "static" # how to choose mask length
191
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
192
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
193
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
194
+
195
+ # channel masking
196
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
197
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
198
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
199
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
200
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
201
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
202
+
203
+ # positional embeddings
204
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
205
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
206
+
207
+ # relative position embedding
208
+ self.relative_position_embedding: bool = False # apply relative position embedding
209
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
210
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
211
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
212
+ self.adapter_dim: int = 128
213
+
214
+ if cfg is not None:
215
+ self.update(cfg)
216
+
217
+ def update(self, cfg: dict):
218
+ self.__dict__.update(cfg)
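
Editorial sketch (not part of the committed file): the config is a plain attribute container, so fields can be overridden via a dict, e.g. to switch on the (gated) relative position bias.

    cfg = WavLMConfig()
    cfg.update({
        "relative_position_embedding": True,  # bucketed relative position bias
        "gru_rel_pos": True,                  # gated variant
    })
    print(cfg.encoder_embed_dim, cfg.relative_position_embedding)  # 768 True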
219
+
220
+
221
+ class WavLM(nn.Module):
222
+ def __init__(
223
+ self,
224
+ cfg: WavLMConfig,
225
+ ) -> None:
226
+ super().__init__()
227
+ logger.info(f"WavLM Config: {cfg.__dict__}")
228
+
229
+ self.cfg = cfg
230
+ feature_enc_layers = eval(cfg.conv_feature_layers)
231
+ self.embed = feature_enc_layers[-1][0]
232
+
233
+ self.feature_extractor = ConvFeatureExtractionModel(
234
+ conv_layers=feature_enc_layers,
235
+ dropout=0.0,
236
+ mode=cfg.extractor_mode,
237
+ conv_bias=cfg.conv_bias,
238
+ )
239
+
240
+ self.post_extract_proj = (
241
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
242
+ if self.embed != cfg.encoder_embed_dim
243
+ else None
244
+ )
245
+
246
+ self.mask_prob = cfg.mask_prob
247
+ self.mask_selection = cfg.mask_selection
248
+ self.mask_other = cfg.mask_other
249
+ self.mask_length = cfg.mask_length
250
+ self.no_mask_overlap = cfg.no_mask_overlap
251
+ self.mask_min_space = cfg.mask_min_space
252
+
253
+ self.mask_channel_prob = cfg.mask_channel_prob
254
+ self.mask_channel_selection = cfg.mask_channel_selection
255
+ self.mask_channel_other = cfg.mask_channel_other
256
+ self.mask_channel_length = cfg.mask_channel_length
257
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
258
+ self.mask_channel_min_space = cfg.mask_channel_min_space
259
+
260
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
261
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
262
+
263
+ self.feature_grad_mult = cfg.feature_grad_mult
264
+
265
+ self.encoder = TransformerEncoder(cfg)
266
+ self.layer_norm = LayerNorm(self.embed)
267
+
268
+ def apply_mask(self, x, padding_mask):
269
+ B, T, C = x.shape
270
+ if self.mask_prob > 0:
271
+ mask_indices = compute_mask_indices(
272
+ (B, T),
273
+ padding_mask,
274
+ self.mask_prob,
275
+ self.mask_length,
276
+ self.mask_selection,
277
+ self.mask_other,
278
+ min_masks=2,
279
+ no_overlap=self.no_mask_overlap,
280
+ min_space=self.mask_min_space,
281
+ )
282
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
283
+ x[mask_indices] = self.mask_emb
284
+ else:
285
+ mask_indices = None
286
+
287
+ if self.mask_channel_prob > 0:
288
+ mask_channel_indices = compute_mask_indices(
289
+ (B, C),
290
+ None,
291
+ self.mask_channel_prob,
292
+ self.mask_channel_length,
293
+ self.mask_channel_selection,
294
+ self.mask_channel_other,
295
+ no_overlap=self.no_mask_channel_overlap,
296
+ min_space=self.mask_channel_min_space,
297
+ )
298
+ mask_channel_indices = (
299
+ torch.from_numpy(mask_channel_indices)
300
+ .to(x.device)
301
+ .unsqueeze(1)
302
+ .expand(-1, T, -1)
303
+ )
304
+ x[mask_channel_indices] = 0
305
+
306
+ return x, mask_indices
307
+
308
+ def forward_padding_mask(
309
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
310
+ ) -> torch.Tensor:
311
+ extra = padding_mask.size(1) % features.size(1)
312
+ if extra > 0:
313
+ padding_mask = padding_mask[:, :-extra]
314
+ padding_mask = padding_mask.view(
315
+ padding_mask.size(0), features.size(1), -1
316
+ )
317
+ padding_mask = padding_mask.all(-1)
318
+ return padding_mask
319
+
320
+ def extract_features(
321
+ self,
322
+ source: torch.Tensor,
323
+ padding_mask: Optional[torch.Tensor] = None,
324
+ mask: bool = False,
325
+ ret_conv: bool = False,
326
+ output_layer: Optional[int] = None,
327
+ ret_layer_results: bool = False,
328
+ ):
329
+ if self.feature_grad_mult > 0:
330
+ features = self.feature_extractor(source)
331
+ features = features[-1].transpose(1, 2)
332
+ if self.feature_grad_mult != 1.0:
333
+ features = GradMultiply.apply(features, self.feature_grad_mult)
334
+ else:
335
+ with torch.no_grad():
336
+ features = self.feature_extractor(source)
337
+ features = features[-1].transpose(1, 2)
338
+
339
+
340
+ cnn_outs = features
341
+ features = self.layer_norm(features)
342
+
343
+ if padding_mask is not None:
344
+ padding_mask = self.forward_padding_mask(features, padding_mask)
345
+
346
+ if self.post_extract_proj is not None:
347
+ features = self.post_extract_proj(features)
348
+
349
+ features = self.dropout_input(features)
350
+
351
+ if mask:
352
+ x, mask_indices = self.apply_mask(
353
+ features, padding_mask
354
+ )
355
+ else:
356
+ x = features
357
+
358
+ # feature: (B, T, D), float
359
+ # target: (B, T), long
360
+ # x: (B, T, D), float
361
+ # padding_mask: (B, T), bool
362
+ # mask_indices: (B, T), bool
363
+ x, layer_results = self.encoder(
364
+ x,
365
+ padding_mask=padding_mask,
366
+ layer=None if output_layer is None else output_layer - 1
367
+ )
368
+ return cnn_outs, layer_results
369
+ # res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
370
+
371
+ # feature = res["features"] if ret_conv else res["x"]
372
+ # if ret_layer_results:
373
+ # feature = (feature, res["layer_results"])
374
+ # return feature, res["padding_mask"]
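
Editorial sketch of how the class above can be driven end to end. The checkpoint layout ("cfg" and "model" keys) is an assumption borrowed from the original WavLM release; the model_convert.pt shipped in this repo may be organized differently.

    checkpoint = torch.load("SSL_WavLM/model_convert.pt", map_location="cpu")
    cfg = WavLMConfig(checkpoint.get("cfg"))                    # assumed key
    model = WavLM(cfg)
    model.load_state_dict(checkpoint["model"], strict=False)    # assumed key
    model.eval()

    wav = torch.randn(1, 16000)                                 # 1 s of 16 kHz audio
    if cfg.normalize:
        wav = torch.nn.functional.layer_norm(wav, wav.shape)
    with torch.no_grad():
        cnn_outs, layer_results = model.extract_features(wav, output_layer=cfg.encoder_layers)
    # cnn_outs: conv-extractor features of shape (B, T', 512);
    # layer_results: 1 + encoder_layers tuples (x, z) with x in (T', B, D) layout.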
375
+
376
+
377
+ class ConvFeatureExtractionModel(nn.Module):
378
+ def __init__(
379
+ self,
380
+ conv_layers: List[Tuple[int, int, int]],
381
+ dropout: float = 0.0,
382
+ mode: str = "default",
383
+ conv_bias: bool = False,
384
+ conv_type: str = "default"
385
+ ):
386
+ super().__init__()
387
+
388
+ assert mode in {"default", "layer_norm"}
389
+
390
+ def block(
391
+ n_in,
392
+ n_out,
393
+ k,
394
+ stride,
395
+ is_layer_norm=False,
396
+ is_group_norm=False,
397
+ conv_bias=False,
398
+ ):
399
+ def make_conv():
400
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
401
+ nn.init.kaiming_normal_(conv.weight)
402
+ return conv
403
+
404
+ assert (
405
+ is_layer_norm and is_group_norm
406
+ ) == False, "layer norm and group norm are exclusive"
407
+
408
+ if is_layer_norm:
409
+ return nn.Sequential(
410
+ make_conv(),
411
+ nn.Dropout(p=dropout),
412
+ nn.Sequential(
413
+ TransposeLast(),
414
+ Fp32LayerNorm(dim, elementwise_affine=True),
415
+ TransposeLast(),
416
+ ),
417
+ nn.GELU(),
418
+ )
419
+ elif is_group_norm:
420
+ return nn.Sequential(
421
+ make_conv(),
422
+ nn.Dropout(p=dropout),
423
+ Fp32GroupNorm(dim, dim, affine=True),
424
+ nn.GELU(),
425
+ )
426
+ else:
427
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
428
+
429
+ self.conv_type = conv_type
430
+ if self.conv_type == "default":
431
+ in_d = 1
432
+ self.conv_layers = nn.ModuleList()
433
+ for i, cl in enumerate(conv_layers):
434
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
435
+ (dim, k, stride) = cl
436
+
437
+ self.conv_layers.append(
438
+ block(
439
+ in_d,
440
+ dim,
441
+ k,
442
+ stride,
443
+ is_layer_norm=mode == "layer_norm",
444
+ is_group_norm=mode == "default" and i == 0,
445
+ conv_bias=conv_bias,
446
+ )
447
+ )
448
+ in_d = dim
449
+ elif self.conv_type == "conv2d":
450
+ in_d = 1
451
+ self.conv_layers = nn.ModuleList()
452
+ for i, cl in enumerate(conv_layers):
453
+ assert len(cl) == 3
454
+ (dim, k, stride) = cl
455
+
456
+ self.conv_layers.append(
457
+ torch.nn.Conv2d(in_d, dim, k, stride)
458
+ )
459
+ self.conv_layers.append(torch.nn.ReLU())
460
+ in_d = dim
461
+ elif self.conv_type == "custom":
462
+ in_d = 1
463
+ idim = 80
464
+ self.conv_layers = nn.ModuleList()
465
+ for i, cl in enumerate(conv_layers):
466
+ assert len(cl) == 3
467
+ (dim, k, stride) = cl
468
+ self.conv_layers.append(
469
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
470
+ )
471
+ self.conv_layers.append(
472
+ torch.nn.LayerNorm([dim, idim])
473
+ )
474
+ self.conv_layers.append(torch.nn.ReLU())
475
+ in_d = dim
476
+ if (i + 1) % 2 == 0:
477
+ self.conv_layers.append(
478
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
479
+ )
480
+ idim = int(math.ceil(idim / 2))
481
+ else:
482
+ pass
483
+ def forward(self, x, mask=None):
484
+
485
+ # BxT -> BxCxT
486
+ x_lst = []
487
+ x = x.unsqueeze(1)
488
+ if self.conv_type == "custom":
489
+ for conv in self.conv_layers:
490
+ if isinstance(conv, nn.LayerNorm):
491
+ x = x.transpose(1, 2)
492
+ x = conv(x).transpose(1, 2)
493
+ else:
494
+ x = conv(x)
495
+ x = x.transpose(2, 3).contiguous()
496
+ x = x.view(x.size(0), -1, x.size(-1))
497
+ else:
498
+ for conv in self.conv_layers:
499
+ x = conv(x)
500
+ x_lst.append(x)
501
+ if self.conv_type == "conv2d":
502
+ b, c, t, f = x.size()
503
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
504
+ return x_lst
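
Editorial note: with the default conv_feature_layers string from WavLMConfig, the stack above downsamples raw audio by a fixed factor, which a quick sketch makes explicit.

    import math

    layers = eval("[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
    total_stride = math.prod(stride for _, _, stride in layers)
    print(total_stride)  # 320 input samples per frame, i.e. ~50 Hz frames for 16 kHz audio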
505
+
506
+
507
+ class TransformerEncoder(nn.Module):
508
+ def __init__(self, args):
509
+ super().__init__()
510
+
511
+ self.dropout = args.dropout
512
+ self.embedding_dim = args.encoder_embed_dim
513
+
514
+ self.pos_conv = nn.Conv1d(
515
+ self.embedding_dim,
516
+ self.embedding_dim,
517
+ kernel_size=args.conv_pos,
518
+ padding=args.conv_pos // 2,
519
+ groups=args.conv_pos_groups,
520
+ )
521
+ dropout = 0
522
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
523
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
524
+ nn.init.constant_(self.pos_conv.bias, 0)
525
+
526
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
527
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
528
+
529
+ if hasattr(args, "relative_position_embedding"):
530
+ self.relative_position_embedding = args.relative_position_embedding
531
+ self.num_buckets = args.num_buckets
532
+ self.max_distance = args.max_distance
533
+ else:
534
+ self.relative_position_embedding = False
535
+ self.num_buckets = 0
536
+ self.max_distance = 0
537
+
538
+ self.layers = nn.ModuleList(
539
+ [
540
+ TransformerSentenceEncoderLayer(
541
+ embedding_dim=self.embedding_dim,
542
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
543
+ num_attention_heads=args.encoder_attention_heads,
544
+ dropout=self.dropout,
545
+ attention_dropout=args.attention_dropout,
546
+ activation_dropout=args.activation_dropout,
547
+ activation_fn=args.activation_fn,
548
+ layer_norm_first=args.layer_norm_first,
549
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
550
+ num_buckets=self.num_buckets,
551
+ max_distance=self.max_distance,
552
+ gru_rel_pos=args.gru_rel_pos,
553
+ adapter_dim=args.adapter_dim,
554
+ )
555
+ for i in range(args.encoder_layers)
556
+ ]
557
+ )
558
+
559
+ self.layer_norm_first = args.layer_norm_first
560
+ self.layer_norm = LayerNorm(self.embedding_dim)
561
+ self.layerdrop = args.encoder_layerdrop
562
+
563
+ def __prepare_scriptable__(self):
564
+ for hook in self.pos_conv._forward_pre_hooks.values():
565
+ # The hook we want to remove is an instance of WeightNorm class, so
566
+ # normally we would do `if isinstance(...)` but this class is not accessible
567
+ # because of shadowing, so we check the module name directly.
568
+ # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
569
+ if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
570
+ logger.warning("Removing weight_norm from %s", self.__class__.__name__)
571
+ torch.nn.utils.remove_weight_norm(self.pos_conv)
572
+ return self
573
+
574
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
575
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
576
+
577
+ if self.layer_norm_first and layer is None:
578
+ x = self.layer_norm(x)
579
+
580
+ return x, layer_results
581
+
582
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
583
+
584
+ if padding_mask is not None:
585
+ x[padding_mask] = 0
586
+
587
+ x_conv = self.pos_conv(x.transpose(1, 2))
588
+ x_conv = x_conv.transpose(1, 2)
589
+ x = x + x_conv
590
+
591
+ if not self.layer_norm_first:
592
+ x = self.layer_norm(x)
593
+
594
+ x = F.dropout(x, p=self.dropout, training=self.training)
595
+
596
+ # B x T x C -> T x B x C
597
+ x = x.transpose(0, 1)
598
+
599
+ layer_results = []
600
+ z = None
601
+ if tgt_layer is not None:
602
+ layer_results.append((x, z))
603
+ r = None
604
+ pos_bias = None
605
+ for i, layer in enumerate(self.layers):
606
+ # dropout_probability = np.random.random()
607
+ if not self.training or (torch.rand(1).item() > self.layerdrop):
608
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
609
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
610
+ if tgt_layer is not None:
611
+ layer_results.append((x, z))
612
+ if i == tgt_layer:
613
+ r = x
614
+ break
615
+
616
+ if r is not None:
617
+ x = r
618
+
619
+ # T x B x C -> B x T x C
620
+ x = x.transpose(0, 1)
621
+
622
+ return x, layer_results
623
+
624
+
625
+ class TransformerSentenceEncoderLayer(nn.Module):
626
+ """
627
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
628
+ models.
629
+ """
630
+
631
+ def __init__(
632
+ self,
633
+ embedding_dim: float = 768,
634
+ ffn_embedding_dim: float = 3072,
635
+ num_attention_heads: float = 8,
636
+ dropout: float = 0.1,
637
+ attention_dropout: float = 0.1,
638
+ activation_dropout: float = 0.1,
639
+ activation_fn: str = "relu",
640
+ layer_norm_first: bool = False,
641
+ has_relative_attention_bias: bool = False,
642
+ num_buckets: int = 0,
643
+ max_distance: int = 0,
644
+ rescale_init: bool = False,
645
+ gru_rel_pos: bool = False,
646
+ adapter_dim: int = 128,
647
+ ) -> None:
648
+
649
+ super().__init__()
650
+ # Initialize parameters
651
+ self.embedding_dim = embedding_dim
652
+ self.dropout = dropout
653
+ self.activation_dropout = activation_dropout
654
+
655
+ # Initialize blocks
656
+ self.activation_name = activation_fn
657
+ self.activation_fn = get_activation_fn(activation_fn)
658
+ self.self_attn = MultiheadAttention(
659
+ self.embedding_dim,
660
+ num_attention_heads,
661
+ dropout=attention_dropout,
662
+ self_attention=True,
663
+ has_relative_attention_bias=has_relative_attention_bias,
664
+ num_buckets=num_buckets,
665
+ max_distance=max_distance,
666
+ rescale_init=rescale_init,
667
+ gru_rel_pos=gru_rel_pos,
668
+ )
669
+
670
+ self.dropout1 = nn.Dropout(dropout)
671
+ self.dropout2 = nn.Dropout(self.activation_dropout)
672
+ self.dropout3 = nn.Dropout(dropout)
673
+
674
+ self.layer_norm_first = layer_norm_first
675
+
676
+ # layer norm associated with the self attention layer
677
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
678
+
679
+ if self.activation_name == "glu":
680
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
681
+ else:
682
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
683
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
684
+
685
+ import torchaudio.functional as AudioF
686
+
687
+
688
+ # layer norm associated with the position wise feed-forward NN
689
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
690
+
691
+ def forward(
692
+ self,
693
+ x: torch.Tensor,
694
+ self_attn_mask: torch.Tensor = None,
695
+ self_attn_padding_mask: torch.Tensor = None,
696
+ need_weights: bool = False,
697
+ pos_bias=None
698
+ ):
699
+ """
700
+ LayerNorm is applied either before or after the self-attention/ffn
701
+ modules similar to the original Transformer implementation.
702
+ """
703
+ residual = x
704
+
705
+ if self.layer_norm_first:
706
+ x = self.self_attn_layer_norm(x)
707
+ x, attn, pos_bias = self.self_attn(
708
+ query=x,
709
+ key=x,
710
+ value=x,
711
+ key_padding_mask=self_attn_padding_mask,
712
+ need_weights=False,
713
+ attn_mask=self_attn_mask,
714
+ position_bias=pos_bias
715
+ )
716
+ x = self.dropout1(x)
717
+ x = residual + x
718
+
719
+ residual = x
720
+
721
+ x = self.final_layer_norm(x)
722
+ if self.activation_name == "glu":
723
+ x = self.fc1(x)
724
+ else:
725
+ x = self.activation_fn(self.fc1(x))
726
+ x = self.dropout2(x)
727
+ x = self.fc2(x)
728
+ x = self.dropout3(x)
729
+ x = residual + x
730
+ else:
731
+
732
+
733
+ x, attn, pos_bias = self.self_attn(
734
+ query=x,
735
+ key=x,
736
+ value=x,
737
+ key_padding_mask=self_attn_padding_mask,
738
+ need_weights=need_weights,
739
+ attn_mask=self_attn_mask,
740
+ position_bias=pos_bias
741
+ )
742
+
743
+ x = self.dropout1(x)
744
+ x = residual + x
745
+
746
+ x = self.self_attn_layer_norm(x)
747
+ # MAM
748
+
749
+ residual = x
750
+ if self.activation_name == "glu":
751
+ x = self.fc1(x)
752
+ else:
753
+ x = self.activation_fn(self.fc1(x))
754
+ x = self.dropout2(x)
755
+ x = self.fc2(x)
756
+ x = self.dropout3(x)
757
+ x = residual + x
758
+
759
+
760
+ x = self.final_layer_norm(x)
761
+
762
+ return x, attn, pos_bias
SSL_WavLM/__init__.py ADDED
File without changes
SSL_WavLM/__pycache__/WavLM.cpython-310.pyc ADDED
Binary file (17.2 kB)
 
SSL_WavLM/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (200 Bytes)
 
SSL_WavLM/__pycache__/modules.cpython-310.pyc ADDED
Binary file (19.3 kB)
 
SSL_WavLM/model_convert.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b7dfae027dabfae1fc2efd919a3cebe8f12a22eb6724e2856aaefd5d06c172d
3
+ size 401044050
SSL_WavLM/modules.py ADDED
@@ -0,0 +1,827 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function
88
+ """
89
+
90
+ def __init__(self):
91
+ """Construct an MultiHeadedAttention object."""
92
+ super(Swish, self).__init__()
93
+ self.act = torch.nn.Sigmoid()
94
+
95
+ def forward(self, x):
96
+ return x * self.act(x)
97
+
98
+
99
+ class GLU_Linear(nn.Module):
100
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
+ super(GLU_Linear, self).__init__()
102
+
103
+ self.glu_type = glu_type
104
+ self.output_dim = output_dim
105
+
106
+ if glu_type == "sigmoid":
107
+ self.glu_act = torch.nn.Sigmoid()
108
+ elif glu_type == "swish":
109
+ self.glu_act = Swish()
110
+ elif glu_type == "relu":
111
+ self.glu_act = torch.nn.ReLU()
112
+ elif glu_type == "gelu":
113
+ self.glu_act = torch.nn.GELU()
114
+
115
+ if bias_in_glu:
116
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
+ else:
118
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
+
120
+ def forward(self, x):
121
+ # the input is assumed to have the channel (dim) in the last dimension of the tensor; for the 1D-Conv case the dimensions need to be switched first
122
+ x = self.linear(x)
123
+
124
+ if self.glu_type == "bilinear":
125
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
+ else:
127
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
+
129
+ return x
130
+
131
+
132
+ def gelu_accurate(x):
133
+ if not hasattr(gelu_accurate, "_a"):
134
+ gelu_accurate._a = math.sqrt(2 / math.pi)
135
+ return (
136
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
+ )
138
+
139
+
140
+ def gelu(x: torch.Tensor) -> torch.Tensor:
141
+ return torch.nn.functional.gelu(x.float()).type_as(x)
142
+
143
+
144
+ def get_activation_fn(activation: str):
145
+ """Returns the activation function corresponding to `activation`"""
146
+
147
+ if activation == "relu":
148
+ return F.relu
149
+ elif activation == "gelu":
150
+ return gelu
151
+ elif activation == "gelu_fast":
152
+ warnings.warn(
153
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
+ )
155
+ return gelu_accurate
156
+ elif activation == "gelu_accurate":
157
+ return gelu_accurate
158
+ elif activation == "tanh":
159
+ return torch.tanh
160
+ elif activation == "linear":
161
+ return lambda x: x
162
+ elif activation == "glu":
163
+ return lambda x: x
164
+ else:
165
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
166
+
167
+
168
+ def init_bert_params(module):
169
+ """
170
+ Initialize the weights specific to the BERT Model.
171
+ This overrides the default initializations depending on the specified arguments.
172
+ 1. If normal_init_linear_weights is set then weights of linear
173
+ layer will be initialized using the normal distribution and
174
+ bais will be set to the specified value.
175
+ 2. If normal_init_embed_weights is set then weights of embedding
176
+ layer will be initialized using the normal distribution.
177
+ 3. If normal_init_proj_weights is set then weights of
178
+ in_project_weight for MultiHeadAttention initialized using
179
+ the normal distribution (to be validated).
180
+ """
181
+
182
+ def normal_(data):
183
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
+ # so that the RNG is consistent with and without FSDP
185
+ data.copy_(
186
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
+ )
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
201
+
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists in randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert (
236
+ module.weight.size(1) % block_size == 0
237
+ ), "Input features must be a multiple of block sizes"
238
+
239
+ # 4D matrix
240
+ else:
241
+ # 1x1 convolutions
242
+ if module.kernel_size == (1, 1):
243
+ assert (
244
+ module.in_channels % block_size == 0
245
+ ), "Input channels must be a multiple of block sizes"
246
+ # regular convolutions
247
+ else:
248
+ k = module.kernel_size[0] * module.kernel_size[1]
249
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
+
251
+ def _forward_pre_hook(mod, input):
252
+ # no noise for evaluation
253
+ if mod.training:
254
+ if not is_conv:
255
+ # gather weight and sizes
256
+ weight = mod.weight
257
+ in_features = weight.size(1)
258
+ out_features = weight.size(0)
259
+
260
+ # split weight matrix into blocks and randomly drop selected blocks
261
+ mask = torch.zeros(
262
+ in_features // block_size * out_features, device=weight.device
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
+
267
+ else:
268
+ # gather weight and sizes
269
+ weight = mod.weight
270
+ in_channels = mod.in_channels
271
+ out_channels = mod.out_channels
272
+
273
+ # split weight matrix into blocks and randomly drop selected blocks
274
+ if mod.kernel_size == (1, 1):
275
+ mask = torch.zeros(
276
+ int(in_channels // block_size * out_channels),
277
+ device=weight.device,
278
+ )
279
+ mask.bernoulli_(p)
280
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
+ else:
282
+ mask = torch.zeros(
283
+ weight.size(0), weight.size(1), device=weight.device
284
+ )
285
+ mask.bernoulli_(p)
286
+ mask = (
287
+ mask.unsqueeze(2)
288
+ .unsqueeze(3)
289
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
+ )
291
+
292
+ # scale weights and apply mask
293
+ mask = mask.to(
294
+ torch.bool
295
+ ) # x.bool() is not currently supported in TorchScript
296
+ s = 1 / (1 - p)
297
+ mod.weight.data = s * weight.masked_fill(mask, 0)
298
+
299
+ module.register_forward_pre_hook(_forward_pre_hook)
300
+ return module
301
+
302
+
303
+ class MultiheadAttention(nn.Module):
304
+ """Multi-headed attention.
305
+
306
+ See "Attention Is All You Need" for more details.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ embed_dim,
312
+ num_heads,
313
+ kdim=None,
314
+ vdim=None,
315
+ dropout=0.0,
316
+ bias=True,
317
+ add_bias_kv=False,
318
+ add_zero_attn=False,
319
+ self_attention=False,
320
+ encoder_decoder_attention=False,
321
+ q_noise=0.0,
322
+ qn_block_size=8,
323
+ has_relative_attention_bias=False,
324
+ num_buckets=32,
325
+ max_distance=128,
326
+ gru_rel_pos=False,
327
+ rescale_init=False,
328
+ ):
329
+ super().__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout_module = nn.Dropout(dropout)
337
+
338
+ self.has_relative_attention_bias = has_relative_attention_bias
339
+ self.num_buckets = num_buckets
340
+ self.max_distance = max_distance
341
+ if self.has_relative_attention_bias:
342
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
+
344
+ self.head_dim = embed_dim // num_heads
345
+ self.q_head_dim = self.head_dim
346
+ self.k_head_dim = self.head_dim
347
+ assert (
348
+ self.head_dim * num_heads == self.embed_dim
349
+ ), "embed_dim must be divisible by num_heads"
350
+ self.scaling = self.head_dim ** -0.5
351
+
352
+ self.self_attention = self_attention
353
+ self.encoder_decoder_attention = encoder_decoder_attention
354
+
355
+ assert not self.self_attention or self.qkv_same_dim, (
356
+ "Self-attention requires query, key and " "value to be of the same size"
357
+ )
358
+
359
+ k_bias = True
360
+ if rescale_init:
361
+ k_bias = False
362
+
363
+ k_embed_dim = embed_dim
364
+ q_embed_dim = embed_dim
365
+
366
+ self.k_proj = quant_noise(
367
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
+ )
369
+ self.v_proj = quant_noise(
370
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
+ )
372
+ self.q_proj = quant_noise(
373
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
+ )
375
+
376
+ self.out_proj = quant_noise(
377
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
+ )
379
+
380
+ if add_bias_kv:
381
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
+ else:
384
+ self.bias_k = self.bias_v = None
385
+
386
+ self.add_zero_attn = add_zero_attn
387
+
388
+ self.gru_rel_pos = gru_rel_pos
389
+ if self.gru_rel_pos:
390
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
+
393
+ self.reset_parameters()
394
+
395
+ def reset_parameters(self):
396
+ if self.qkv_same_dim:
397
+ # Empirically observed the convergence to be much better with
398
+ # the scaled initialization
399
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
+ else:
403
+ nn.init.xavier_uniform_(self.k_proj.weight)
404
+ nn.init.xavier_uniform_(self.v_proj.weight)
405
+ nn.init.xavier_uniform_(self.q_proj.weight)
406
+
407
+ nn.init.xavier_uniform_(self.out_proj.weight)
408
+ if self.out_proj.bias is not None:
409
+ nn.init.constant_(self.out_proj.bias, 0.0)
410
+ if self.bias_k is not None:
411
+ nn.init.xavier_normal_(self.bias_k)
412
+ if self.bias_v is not None:
413
+ nn.init.xavier_normal_(self.bias_v)
414
+ if self.has_relative_attention_bias:
415
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
+
417
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
418
+ num_buckets = self.num_buckets
419
+ max_distance = self.max_distance
420
+ relative_buckets = 0
421
+
422
+ if bidirectional:
423
+ num_buckets = num_buckets // 2
424
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
+ relative_positions = torch.abs(relative_positions)
426
+ else:
427
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
+
429
+ max_exact = num_buckets // 2
430
+ is_small = relative_positions < max_exact
431
+
432
+ relative_postion_if_large = max_exact + (
433
+ torch.log(relative_positions.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_postion_if_large = torch.min(
438
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
+ )
440
+
441
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
+ return relative_buckets
443
+
444
+ def compute_bias(self, query_length, key_length):
445
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
+ relative_position = memory_position - context_position
448
+ relative_position_bucket = self._relative_positions_bucket(
449
+ relative_position,
450
+ bidirectional=True
451
+ )
452
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
+ values = self.relative_attention_bias(relative_position_bucket)
454
+ values = values.permute([2, 0, 1])
455
+ return values
456
+
457
+ def forward(
458
+ self,
459
+ query,
460
+ key: Optional[Tensor],
461
+ value: Optional[Tensor],
462
+ key_padding_mask: Optional[Tensor] = None,
463
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
+ need_weights: bool = True,
465
+ static_kv: bool = False,
466
+ attn_mask: Optional[Tensor] = None,
467
+ before_softmax: bool = False,
468
+ need_head_weights: bool = False,
469
+ position_bias: Optional[Tensor] = None
470
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
+ """Input shape: Time x Batch x Channel
472
+
473
+ Args:
474
+ key_padding_mask (ByteTensor, optional): mask to exclude
475
+ keys that are pads, of shape `(batch, src_len)`, where
476
+ padding elements are indicated by 1s.
477
+ need_weights (bool, optional): return the attention weights,
478
+ averaged over heads (default: False).
479
+ attn_mask (ByteTensor, optional): typically used to
480
+ implement causal attention, where the mask prevents the
481
+ attention from looking forward in time (default: None).
482
+ before_softmax (bool, optional): return the raw attention
483
+ weights and values before the attention softmax.
484
+ need_head_weights (bool, optional): return the attention
485
+ weights for each head. Implies *need_weights*. Default:
486
+ return the average attention weights over all heads.
487
+ """
488
+ if need_head_weights:
489
+ need_weights = True
490
+
491
+ is_tpu = query.device.type == "xla"
492
+
493
+ tgt_len, bsz, embed_dim = query.size()
494
+ src_len = tgt_len
495
+ assert embed_dim == self.embed_dim
496
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
+ if key is not None:
498
+ src_len, key_bsz, _ = key.size()
499
+ if not torch.jit.is_scripting():
500
+ assert key_bsz == bsz
501
+ assert value is not None
502
+ assert (src_len, bsz) == value.shape[:2]
503
+
504
+ if self.has_relative_attention_bias and position_bias is None:
505
+ position_bias = self.compute_bias(tgt_len, src_len)
506
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
+
508
+ if (
509
+ not is_tpu # don't use PyTorch version on TPUs
510
+ and incremental_state is None
511
+ and not static_kv
512
+ # A workaround for quantization to work. Otherwise JIT compilation
513
+ # treats bias in linear module as method.
514
+ and not torch.jit.is_scripting()
515
+ and self.q_head_dim == self.head_dim
516
+ ):
517
+ assert key is not None and value is not None
518
+ assert attn_mask is None
519
+
520
+ attn_mask_rel_pos = None
521
+ if position_bias is not None:
522
+ attn_mask_rel_pos = position_bias
523
+ if self.gru_rel_pos:
524
+ query_layer = query.transpose(0, 1)
525
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
+ query_layer = query_layer.view(*new_x_shape)
527
+ query_layer = query_layer.permute(0, 2, 1, 3)
528
+ _B, _H, _L, __ = query_layer.size()
529
+
530
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
+
535
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
+ k_proj_bias = self.k_proj.bias
537
+ if k_proj_bias is None:
538
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
+
540
+ x, attn = F.multi_head_attention_forward(
541
+ query,
542
+ key,
543
+ value,
544
+ self.embed_dim,
545
+ self.num_heads,
546
+ torch.empty([0]),
547
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
+ self.bias_k,
549
+ self.bias_v,
550
+ self.add_zero_attn,
551
+ self.dropout_module.p,
552
+ self.out_proj.weight,
553
+ self.out_proj.bias,
554
+ self.training,
555
+ # self.training or self.dropout_module.apply_during_inference,
556
+ key_padding_mask,
557
+ need_weights,
558
+ attn_mask_rel_pos,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj.weight,
561
+ k_proj_weight=self.k_proj.weight,
562
+ v_proj_weight=self.v_proj.weight,
563
+ )
564
+ return x, attn, position_bias
565
+
566
+ if incremental_state is not None:
567
+ saved_state = self._get_input_buffer(incremental_state)
568
+ if saved_state is not None and "prev_key" in saved_state:
569
+ # previous time steps are cached - no need to recompute
570
+ # key and value if they are static
571
+ if static_kv:
572
+ assert self.encoder_decoder_attention and not self.self_attention
573
+ key = value = None
574
+ else:
575
+ saved_state = None
576
+
577
+ if self.self_attention:
578
+ q = self.q_proj(query)
579
+ k = self.k_proj(query)
580
+ v = self.v_proj(query)
581
+ elif self.encoder_decoder_attention:
582
+ # encoder-decoder attention
583
+ q = self.q_proj(query)
584
+ if key is None:
585
+ assert value is None
586
+ k = v = None
587
+ else:
588
+ k = self.k_proj(key)
589
+ v = self.v_proj(key)
590
+
591
+ else:
592
+ assert key is not None and value is not None
593
+ q = self.q_proj(query)
594
+ k = self.k_proj(key)
595
+ v = self.v_proj(value)
596
+ q *= self.scaling
597
+
598
+ if self.bias_k is not None:
599
+ assert self.bias_v is not None
600
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
+ if attn_mask is not None:
603
+ attn_mask = torch.cat(
604
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
+ )
606
+ if key_padding_mask is not None:
607
+ key_padding_mask = torch.cat(
608
+ [
609
+ key_padding_mask,
610
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
+ ],
612
+ dim=1,
613
+ )
614
+
615
+ q = (
616
+ q.contiguous()
617
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
+ .transpose(0, 1)
619
+ )
620
+ if k is not None:
621
+ k = (
622
+ k.contiguous()
623
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
624
+ .transpose(0, 1)
625
+ )
626
+ if v is not None:
627
+ v = (
628
+ v.contiguous()
629
+ .view(-1, bsz * self.num_heads, self.head_dim)
630
+ .transpose(0, 1)
631
+ )
632
+
633
+ if saved_state is not None:
634
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
+ if "prev_key" in saved_state:
636
+ _prev_key = saved_state["prev_key"]
637
+ assert _prev_key is not None
638
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
+ if static_kv:
640
+ k = prev_key
641
+ else:
642
+ assert k is not None
643
+ k = torch.cat([prev_key, k], dim=1)
644
+ src_len = k.size(1)
645
+ if "prev_value" in saved_state:
646
+ _prev_value = saved_state["prev_value"]
647
+ assert _prev_value is not None
648
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
+ if static_kv:
650
+ v = prev_value
651
+ else:
652
+ assert v is not None
653
+ v = torch.cat([prev_value, v], dim=1)
654
+ prev_key_padding_mask: Optional[Tensor] = None
655
+ if "prev_key_padding_mask" in saved_state:
656
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
+ assert k is not None and v is not None
658
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
+ key_padding_mask=key_padding_mask,
660
+ prev_key_padding_mask=prev_key_padding_mask,
661
+ batch_size=bsz,
662
+ src_len=k.size(1),
663
+ static_kv=static_kv,
664
+ )
665
+
666
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
+ saved_state["prev_key_padding_mask"] = key_padding_mask
669
+ # In this branch incremental_state is never None
670
+ assert incremental_state is not None
671
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
+ assert k is not None
673
+ assert k.size(1) == src_len
674
+
675
+ # This is part of a workaround to get around fork/join parallelism
676
+ # not supporting Optional types.
677
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
+ key_padding_mask = None
679
+
680
+ if key_padding_mask is not None:
681
+ assert key_padding_mask.size(0) == bsz
682
+ assert key_padding_mask.size(1) == src_len
683
+
684
+ if self.add_zero_attn:
685
+ assert v is not None
686
+ src_len += 1
687
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
+ if attn_mask is not None:
690
+                 attn_mask = torch.cat(
+                     [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                 )
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [
+                         key_padding_mask,
+                         torch.zeros(key_padding_mask.size(0), 1).type_as(
+                             key_padding_mask
+                         ),
+                     ],
+                     dim=1,
+                 )
+
+         attn_weights = torch.bmm(q, k.transpose(1, 2))
+         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+         if attn_mask is not None:
+             attn_mask = attn_mask.unsqueeze(0)
+             attn_weights += attn_mask
+
+         if key_padding_mask is not None:
+             # don't attend to padding symbols
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             if not is_tpu:
+                 attn_weights = attn_weights.masked_fill(
+                     key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                     float("-inf"),
+                 )
+             else:
+                 attn_weights = attn_weights.transpose(0, 2)
+                 attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+                 attn_weights = attn_weights.transpose(0, 2)
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         if before_softmax:
+             return attn_weights, v, position_bias
+
+         if position_bias is not None:
+             if self.gru_rel_pos == 1:
+                 query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
+                 _B, _H, _L, __ = query_layer.size()
+                 gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
+                     _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
+                 gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+                 position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+
+             position_bias = position_bias.view(attn_weights.size())
+
+             attn_weights = attn_weights + position_bias
+
+         attn_weights_float = F.softmax(
+             attn_weights, dim=-1
+         )
+         attn_weights = attn_weights_float.type_as(attn_weights)
+         attn_probs = self.dropout_module(attn_weights)
+
+         assert v is not None
+         attn = torch.bmm(attn_probs, v)
+         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+         attn = self.out_proj(attn)
+         attn_weights: Optional[Tensor] = None
+         if need_weights:
+             attn_weights = attn_weights_float.view(
+                 bsz, self.num_heads, tgt_len, src_len
+             ).transpose(1, 0)
+             if not need_head_weights:
+                 # average attention weights over heads
+                 attn_weights = attn_weights.mean(dim=0)
+
+         return attn, attn_weights, position_bias
+
+     @staticmethod
+     def _append_prev_key_padding_mask(
+         key_padding_mask: Optional[Tensor],
+         prev_key_padding_mask: Optional[Tensor],
+         batch_size: int,
+         src_len: int,
+         static_kv: bool,
+     ) -> Optional[Tensor]:
+         # saved key padding masks have shape (bsz, seq_len)
+         if prev_key_padding_mask is not None and static_kv:
+             new_key_padding_mask = prev_key_padding_mask
+         elif prev_key_padding_mask is not None and key_padding_mask is not None:
+             new_key_padding_mask = torch.cat(
+                 [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+             )
+         # During incremental decoding, as the padding token enters and
+         # leaves the frame, there will be a time when prev or current
+         # is None
+         elif prev_key_padding_mask is not None:
+             if src_len > prev_key_padding_mask.size(1):
+                 filler = torch.zeros(
+                     (batch_size, src_len - prev_key_padding_mask.size(1)),
+                     device=prev_key_padding_mask.device,
+                 )
+                 new_key_padding_mask = torch.cat(
+                     [prev_key_padding_mask.float(), filler.float()], dim=1
+                 )
+             else:
+                 new_key_padding_mask = prev_key_padding_mask.float()
+         elif key_padding_mask is not None:
+             if src_len > key_padding_mask.size(1):
+                 filler = torch.zeros(
+                     (batch_size, src_len - key_padding_mask.size(1)),
+                     device=key_padding_mask.device,
+                 )
+                 new_key_padding_mask = torch.cat(
+                     [filler.float(), key_padding_mask.float()], dim=1
+                 )
+             else:
+                 new_key_padding_mask = key_padding_mask.float()
+         else:
+             new_key_padding_mask = prev_key_padding_mask
+         return new_key_padding_mask
+
+     def _get_input_buffer(
+         self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+     ) -> Dict[str, Optional[Tensor]]:
+         result = self.get_incremental_state(incremental_state, "attn_state")
+         if result is not None:
+             return result
+         else:
+             empty_result: Dict[str, Optional[Tensor]] = {}
+             return empty_result
+
+     def _set_input_buffer(
+         self,
+         incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+         buffer: Dict[str, Optional[Tensor]],
+     ):
+         return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+     def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+         return attn_weights
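A note on the gated relative position bias branch above (the `gru_rel_pos == 1` path): it derives two sigmoid gates from each query and uses them to rescale the shared relative position bias per head and position. Below is a minimal standalone sketch of just that gating, with illustrative sizes and stand-ins for `grep_linear` / `grep_a` (the stand-in shapes are assumptions for the sketch, not parameters read from any checkpoint):

import torch
import torch.nn as nn

# Illustrative sizes only.
bsz, num_heads, tgt_len, q_head_dim = 2, 4, 10, 64
src_len = tgt_len

q = torch.randn(bsz, num_heads, tgt_len, q_head_dim)            # queries per head
position_bias = torch.randn(bsz * num_heads, tgt_len, src_len)  # shared relative bias

# Stand-ins for self.grep_linear and self.grep_a in the class above (assumed shapes).
grep_linear = nn.Linear(q_head_dim, 8)
grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

# Project each query to 8 values, fold them into two groups of 4, sum each group,
# and pass the result through a sigmoid -> two gates per (batch, head, position).
gates = torch.sigmoid(grep_linear(q).view(bsz, num_heads, tgt_len, 2, 4).sum(-1))
gate_a, gate_b = gates.chunk(2, dim=-1)              # each [B, H, T, 1]

# Per-query scalar gate that rescales the shared position bias.
gate_a_1 = gate_a * (gate_b * grep_a - 1.0) + 2.0    # [B, H, T, 1]
scaled_bias = gate_a_1.view(bsz * num_heads, -1, 1) * position_bias
print(scaled_bias.shape)                             # torch.Size([8, 10, 10])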
Transformer_WavLM.py ADDED
@@ -0,0 +1,148 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from SSL_WavLM.WavLM import WavLMConfig, WavLM
+ from collections import OrderedDict
+
+
+ class MHFA(nn.Module):
+     """
+     Multi-Head Factorized Attentive (MHFA) pooling.
+     This layer takes representations from all layers of a model (like WavLM)
+     and aggregates them into a fixed-size embedding using a multi-head
+     attention-like mechanism.
+     """
+     def __init__(self, head_nb=8, inputs_dim=768, compression_dim=128, outputs_dim=256, nb_layer=13):
+         super(MHFA, self).__init__()
+         # Learnable weights to compute a weighted average over the layers
+         self.weights_k = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
+         self.weights_v = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
+
+         self.head_nb = head_nb
+         self.cmp_dim = compression_dim
+
+         # Linear layers for processing
+         self.cmp_linear_k = nn.Linear(inputs_dim, self.cmp_dim)
+         self.cmp_linear_v = nn.Linear(inputs_dim, self.cmp_dim)
+         self.att_head = nn.Linear(self.cmp_dim, self.head_nb)
+         self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, outputs_dim)
+
+     def forward(self, x):
+         # Input x shape: [Batch, Dim, Frame_len, Nb_Layer]
+
+         # 1. Compute weighted average for Key and Value across layers
+         # The softmax ensures the layer weights sum to 1.
+         k = torch.sum(x * F.softmax(self.weights_k, dim=-1), dim=-1).transpose(1, 2)
+         v = torch.sum(x * F.softmax(self.weights_v, dim=-1), dim=-1).transpose(1, 2)
+         # Shape of k, v is now [Batch, Frame_len, Dim]
+
+         # 2. Compress Key and Value representations
+         k = self.cmp_linear_k(k)  # -> [B, T, cmp_dim]
+         v = self.cmp_linear_v(v)  # -> [B, T, cmp_dim]
+
+         # 3. Compute attention scores from the compressed key
+         att_scores = self.att_head(k)  # -> [B, T, head_nb]
+         att_weights = F.softmax(att_scores, dim=1)  # Softmax over the time dimension
+
+         # 4. Perform attention pooling
+         # Reshape for broadcasting:
+         #   v:           [B, T, 1, cmp_dim]
+         #   att_weights: [B, T, head_nb, 1]
+         # The multiplication broadcasts to [B, T, head_nb, cmp_dim]
+         pooled_features = torch.sum(v.unsqueeze(-2) * att_weights.unsqueeze(-1), dim=1)
+         # Summing over the time dimension results in [B, head_nb, cmp_dim]
+
+         # 5. Flatten and project to the final output dimension
+         b, h, f = pooled_features.shape
+         pooled_features = pooled_features.reshape(b, -1)  # -> [B, head_nb * cmp_dim]
+         output_embedding = self.pooling_fc(pooled_features)  # -> [B, outputs_dim]
+
+         return output_embedding
+
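A minimal shape check for the MHFA module above; the batch size and frame count are made up, while 768 dims and 13 layers simply match the class defaults:

import torch

pool = MHFA(head_nb=8, inputs_dim=768, compression_dim=128, outputs_dim=256, nb_layer=13)
dummy = torch.randn(2, 768, 50, 13)   # [Batch, Dim, Frame_len, Nb_Layer]
emb = pool(dummy)
print(emb.shape)                      # torch.Size([2, 256])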
+ class WavLM_MHFA(nn.Module):
+     """
+     The main model that combines a pre-trained WavLM with the MHFA backend.
+     """
+     def __init__(self, model_path):
+         super(WavLM_MHFA, self).__init__()
+
+         print(f"Loading base model checkpoint from: {model_path}")
+         # Use map_location so loading also works on CPU-only machines
+         checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
+
+         # Access the config dictionary stored in the checkpoint
+         cfg_dict = checkpoint['cfg']
+         cfg = WavLMConfig(cfg_dict)
+         self.model = WavLM(cfg)
+
+         inputs_dim = checkpoint['cfg']['encoder_embed_dim']
+         nb_layer = checkpoint['cfg']['encoder_layers'] + 1
+
+         self.back_end = MHFA(inputs_dim=inputs_dim, head_nb=32, outputs_dim=256, nb_layer=nb_layer)
+
+         # Load the pre-trained weights for the WavLM part of the model
+         self.load_checkpoint(checkpoint['model'])
+
+     def load_checkpoint(self, checkpoint_state):
+         loaded_state = checkpoint_state
+
+         # Create a new state_dict to hold the cleaned keys
+         cleaned_state_dict = OrderedDict()
+
+         # Handle checkpoints that might be nested (e.g., inside a 'speaker_extractor')
+         prefix_to_strip = 'speaker_extractor.'
+         for k, v in loaded_state.items():
+             if 'projection' in k:
+                 continue
+             if k.startswith(prefix_to_strip):
+                 cleaned_key = k[len(prefix_to_strip):]
+                 cleaned_state_dict[cleaned_key] = v
+             else:
+                 cleaned_state_dict[k] = v
+
+         # Now load the cleaned state_dict into the current model
+         super().load_state_dict(cleaned_state_dict, strict=True)
+         print("Successfully loaded weights for both WavLM and MHFA backend.")
+
+     def forward(self, raw_wav):
+         # Extract representations from all WavLM layers; wrap this call in
+         # torch.no_grad() (with self.model in eval mode) if the front-end should stay frozen
+         _, layer_results = self.model.extract_features(raw_wav, output_layer=100)
+
+         # Prepare layer representations for the MHFA backend
+         # Input layer_results: List of (Time, Batch, Dim) tensors
+         # Stack them to create [Batch, Time, Dim, Nb_Layer]
+         stacked_reps = torch.stack([x.transpose(0, 1) for x, _ in layer_results], dim=-1)
+         # Permute to match MHFA input: [Batch, Dim, Time, Nb_Layer]
+         layer_reps = stacked_reps.permute(0, 2, 1, 3)
+
+         # The backend part is trainable
+         spk_embedding = self.back_end(layer_reps)
+
+         return spk_embedding
+
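The stacking and permutation in forward() can be sanity-checked in isolation with dummy layer outputs; the sizes below (3 layers, 50 frames, batch 2, 768 dims) are illustrative only:

import torch

layer_results = [(torch.randn(50, 2, 768), None) for _ in range(3)]           # (Time, Batch, Dim) per layer
stacked = torch.stack([x.transpose(0, 1) for x, _ in layer_results], dim=-1)  # [B, T, D, L]
layer_reps = stacked.permute(0, 2, 1, 3)                                      # [B, D, T, L]
print(layer_reps.shape)  # torch.Size([2, 768, 50, 3])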
+ if __name__ == "__main__":
+
+     # Step 1: Instantiate the main model
+     # The model path should point to the pre-trained base model (e.g., WavLM-Base+.pt)
+     print("Loading checkpoint file ...")
+
+     base_model_path = './SSL_WavLM/model_convert.pt'
+     model = WavLM_MHFA(model_path=base_model_path)
+     model.eval()  # Set the model to evaluation mode
+
+     print("\nModel WavLM_MHFA initialized successfully.")
+
+     # Step 2: Perform a forward pass with dummy data
+     batch_size = 4
+     audio_samples = 32000  # 2 seconds of audio at 16 kHz
+     dummy_wav = torch.randn(batch_size, audio_samples)
+
+     print(f"\nPerforming forward pass with dummy input of shape: {dummy_wav.shape}")
+
+     speaker_embedding = model(dummy_wav)
+
+     print("Forward pass successful!")
+     print(f"Output speaker embedding shape: {speaker_embedding.shape}")
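The demo above runs the forward pass with autograd enabled. For pure inference, the same call would typically be wrapped in torch.no_grad(); a small variation, reusing `model` and `dummy_wav` from the block above:

with torch.no_grad():
    speaker_embedding = model(dummy_wav)
print(speaker_embedding.shape)  # expected: torch.Size([4, 256])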