ThreadAbort committed
Commit 2bbf5a2 · 1 Parent(s): 06dfc33

chore: remove IndexTTS Python codebase

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. indextts/BigVGAN/ECAPA_TDNN.py +0 -656
  2. indextts/BigVGAN/__init__.py +0 -0
  3. indextts/BigVGAN/activations.py +0 -122
  4. indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  5. indextts/BigVGAN/alias_free_activation/cuda/.gitignore +0 -1
  6. indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  7. indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +0 -76
  8. indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +0 -23
  9. indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +0 -256
  10. indextts/BigVGAN/alias_free_activation/cuda/compat.h +0 -29
  11. indextts/BigVGAN/alias_free_activation/cuda/load.py +0 -121
  12. indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +0 -92
  13. indextts/BigVGAN/alias_free_activation/torch/__init__.py +0 -6
  14. indextts/BigVGAN/alias_free_activation/torch/act.py +0 -31
  15. indextts/BigVGAN/alias_free_activation/torch/filter.py +0 -102
  16. indextts/BigVGAN/alias_free_activation/torch/resample.py +0 -58
  17. indextts/BigVGAN/alias_free_torch/__init__.py +0 -6
  18. indextts/BigVGAN/alias_free_torch/act.py +0 -29
  19. indextts/BigVGAN/alias_free_torch/filter.py +0 -96
  20. indextts/BigVGAN/alias_free_torch/resample.py +0 -49
  21. indextts/BigVGAN/bigvgan.py +0 -534
  22. indextts/BigVGAN/models.py +0 -451
  23. indextts/BigVGAN/nnet/CNN.py +0 -546
  24. indextts/BigVGAN/nnet/__init__.py +0 -0
  25. indextts/BigVGAN/nnet/linear.py +0 -89
  26. indextts/BigVGAN/nnet/normalization.py +0 -670
  27. indextts/BigVGAN/utils.py +0 -101
  28. indextts/__init__.py +0 -0
  29. indextts/cli.py +0 -65
  30. indextts/gpt/__init__.py +0 -0
  31. indextts/gpt/conformer/__init__.py +0 -0
  32. indextts/gpt/conformer/attention.py +0 -312
  33. indextts/gpt/conformer/embedding.py +0 -163
  34. indextts/gpt/conformer/subsampling.py +0 -348
  35. indextts/gpt/conformer_encoder.py +0 -520
  36. indextts/gpt/model.py +0 -713
  37. indextts/gpt/model_v2.py +0 -747
  38. indextts/gpt/perceiver.py +0 -317
  39. indextts/gpt/transformers_beam_search.py +0 -1013
  40. indextts/gpt/transformers_generation_utils.py +0 -0
  41. indextts/gpt/transformers_gpt2.py +0 -1878
  42. indextts/gpt/transformers_modeling_utils.py +0 -0
  43. indextts/infer.py +0 -690
  44. indextts/infer_v2.py +0 -739
  45. indextts/s2mel/dac/__init__.py +0 -16
  46. indextts/s2mel/dac/__main__.py +0 -36
  47. indextts/s2mel/dac/model/__init__.py +0 -4
  48. indextts/s2mel/dac/model/base.py +0 -294
  49. indextts/s2mel/dac/model/dac.py +0 -400
  50. indextts/s2mel/dac/model/discriminator.py +0 -228
indextts/BigVGAN/ECAPA_TDNN.py DELETED
@@ -1,656 +0,0 @@
1
- """A popular speaker recognition and diarization model.
2
-
3
- Authors
4
- * Hwidong Na 2020
5
- """
6
-
7
- import torch # noqa: F401
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
12
- from indextts.BigVGAN.nnet.linear import Linear
13
- from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
14
-
15
-
16
- def length_to_mask(length, max_len=None, dtype=None, device=None):
17
- """Creates a binary mask for each sequence.
18
-
19
- Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
20
-
21
- Arguments
22
- ---------
23
- length : torch.LongTensor
24
- Containing the length of each sequence in the batch. Must be 1D.
25
- max_len : int
26
- Max length for the mask, also the size of the second dimension.
27
- dtype : torch.dtype, default: None
28
- The dtype of the generated mask.
29
- device: torch.device, default: None
30
- The device to put the mask variable.
31
-
32
- Returns
33
- -------
34
- mask : tensor
35
- The binary mask.
36
-
37
- Example
38
- -------
39
- >>> length=torch.Tensor([1,2,3])
40
- >>> mask=length_to_mask(length)
41
- >>> mask
42
- tensor([[1., 0., 0.],
43
- [1., 1., 0.],
44
- [1., 1., 1.]])
45
- """
46
- assert len(length.shape) == 1
47
-
48
- if max_len is None:
49
- max_len = length.max().long().item() # using arange to generate mask
50
- mask = torch.arange(
51
- max_len, device=length.device, dtype=length.dtype
52
- ).expand(len(length), max_len) < length.unsqueeze(1)
53
-
54
- if dtype is None:
55
- dtype = length.dtype
56
-
57
- if device is None:
58
- device = length.device
59
-
60
- mask = torch.as_tensor(mask, dtype=dtype, device=device)
61
- return mask
62
-
63
-
64
- # Skip transpose as much as possible for efficiency
65
- class Conv1d(_Conv1d):
66
- """1D convolution. Skip transpose is used to improve efficiency."""
67
-
68
- def __init__(self, *args, **kwargs):
69
- super().__init__(skip_transpose=True, *args, **kwargs)
70
-
71
-
72
- class BatchNorm1d(_BatchNorm1d):
73
- """1D batch normalization. Skip transpose is used to improve efficiency."""
74
-
75
- def __init__(self, *args, **kwargs):
76
- super().__init__(skip_transpose=True, *args, **kwargs)
77
-
78
-
79
- class TDNNBlock(nn.Module):
80
- """An implementation of TDNN.
81
-
82
- Arguments
83
- ---------
84
- in_channels : int
85
- Number of input channels.
86
- out_channels : int
87
- The number of output channels.
88
- kernel_size : int
89
- The kernel size of the TDNN blocks.
90
- dilation : int
91
- The dilation of the TDNN block.
92
- activation : torch class
93
- A class for constructing the activation layers.
94
- groups : int
95
- The groups size of the TDNN blocks.
96
-
97
- Example
98
- -------
99
- >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
100
- >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
101
- >>> out_tensor = layer(inp_tensor).transpose(1, 2)
102
- >>> out_tensor.shape
103
- torch.Size([8, 120, 64])
104
- """
105
-
106
- def __init__(
107
- self,
108
- in_channels,
109
- out_channels,
110
- kernel_size,
111
- dilation,
112
- activation=nn.ReLU,
113
- groups=1,
114
- ):
115
- super().__init__()
116
- self.conv = Conv1d(
117
- in_channels=in_channels,
118
- out_channels=out_channels,
119
- kernel_size=kernel_size,
120
- dilation=dilation,
121
- groups=groups,
122
- )
123
- self.activation = activation()
124
- self.norm = BatchNorm1d(input_size=out_channels)
125
-
126
- def forward(self, x):
127
- """Processes the input tensor x and returns an output tensor."""
128
- return self.norm(self.activation(self.conv(x)))
129
-
130
-
131
- class Res2NetBlock(torch.nn.Module):
132
- """An implementation of Res2NetBlock w/ dilation.
133
-
134
- Arguments
135
- ---------
136
- in_channels : int
137
- The number of channels expected in the input.
138
- out_channels : int
139
- The number of output channels.
140
- scale : int
141
- The scale of the Res2Net block.
142
- kernel_size: int
143
- The kernel size of the Res2Net block.
144
- dilation : int
145
- The dilation of the Res2Net block.
146
-
147
- Example
148
- -------
149
- >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
150
- >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
151
- >>> out_tensor = layer(inp_tensor).transpose(1, 2)
152
- >>> out_tensor.shape
153
- torch.Size([8, 120, 64])
154
- """
155
-
156
- def __init__(
157
- self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
158
- ):
159
- super().__init__()
160
- assert in_channels % scale == 0
161
- assert out_channels % scale == 0
162
-
163
- in_channel = in_channels // scale
164
- hidden_channel = out_channels // scale
165
-
166
- self.blocks = nn.ModuleList(
167
- [
168
- TDNNBlock(
169
- in_channel,
170
- hidden_channel,
171
- kernel_size=kernel_size,
172
- dilation=dilation,
173
- )
174
- for i in range(scale - 1)
175
- ]
176
- )
177
- self.scale = scale
178
-
179
- def forward(self, x):
180
- """Processes the input tensor x and returns an output tensor."""
181
- y = []
182
- for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
183
- if i == 0:
184
- y_i = x_i
185
- elif i == 1:
186
- y_i = self.blocks[i - 1](x_i)
187
- else:
188
- y_i = self.blocks[i - 1](x_i + y_i)
189
- y.append(y_i)
190
- y = torch.cat(y, dim=1)
191
- return y
192
-
193
-
194
- class SEBlock(nn.Module):
195
- """An implementation of squeeze-and-excitation block.
196
-
197
- Arguments
198
- ---------
199
- in_channels : int
200
- The number of input channels.
201
- se_channels : int
202
- The number of output channels after squeeze.
203
- out_channels : int
204
- The number of output channels.
205
-
206
- Example
207
- -------
208
- >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
209
- >>> se_layer = SEBlock(64, 16, 64)
210
- >>> lengths = torch.rand((8,))
211
- >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
212
- >>> out_tensor.shape
213
- torch.Size([8, 120, 64])
214
- """
215
-
216
- def __init__(self, in_channels, se_channels, out_channels):
217
- super().__init__()
218
-
219
- self.conv1 = Conv1d(
220
- in_channels=in_channels, out_channels=se_channels, kernel_size=1
221
- )
222
- self.relu = torch.nn.ReLU(inplace=True)
223
- self.conv2 = Conv1d(
224
- in_channels=se_channels, out_channels=out_channels, kernel_size=1
225
- )
226
- self.sigmoid = torch.nn.Sigmoid()
227
-
228
- def forward(self, x, lengths=None):
229
- """Processes the input tensor x and returns an output tensor."""
230
- L = x.shape[-1]
231
- if lengths is not None:
232
- mask = length_to_mask(lengths * L, max_len=L, device=x.device)
233
- mask = mask.unsqueeze(1)
234
- total = mask.sum(dim=2, keepdim=True)
235
- s = (x * mask).sum(dim=2, keepdim=True) / total
236
- else:
237
- s = x.mean(dim=2, keepdim=True)
238
-
239
- s = self.relu(self.conv1(s))
240
- s = self.sigmoid(self.conv2(s))
241
-
242
- return s * x
243
-
244
-
245
- class AttentiveStatisticsPooling(nn.Module):
246
- """This class implements an attentive statistic pooling layer for each channel.
247
- It returns the concatenated mean and std of the input tensor.
248
-
249
- Arguments
250
- ---------
251
- channels: int
252
- The number of input channels.
253
- attention_channels: int
254
- The number of attention channels.
255
- global_context: bool
256
- Whether to use global context.
257
-
258
- Example
259
- -------
260
- >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
261
- >>> asp_layer = AttentiveStatisticsPooling(64)
262
- >>> lengths = torch.rand((8,))
263
- >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
264
- >>> out_tensor.shape
265
- torch.Size([8, 1, 128])
266
- """
267
-
268
- def __init__(self, channels, attention_channels=128, global_context=True):
269
- super().__init__()
270
-
271
- self.eps = 1e-12
272
- self.global_context = global_context
273
- if global_context:
274
- self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
275
- else:
276
- self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
277
- self.tanh = nn.Tanh()
278
- self.conv = Conv1d(
279
- in_channels=attention_channels, out_channels=channels, kernel_size=1
280
- )
281
-
282
- def forward(self, x, lengths=None):
283
- """Calculates mean and std for a batch (input tensor).
284
-
285
- Arguments
286
- ---------
287
- x : torch.Tensor
288
- Tensor of shape [N, C, L].
289
- lengths : torch.Tensor
290
- The corresponding relative lengths of the inputs.
291
-
292
- Returns
293
- -------
294
- pooled_stats : torch.Tensor
295
- mean and std of batch
296
- """
297
- L = x.shape[-1]
298
-
299
- def _compute_statistics(x, m, dim=2, eps=self.eps):
300
- mean = (m * x).sum(dim)
301
- std = torch.sqrt(
302
- (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
303
- )
304
- return mean, std
305
-
306
- if lengths is None:
307
- lengths = torch.ones(x.shape[0], device=x.device)
308
-
309
- # Make binary mask of shape [N, 1, L]
310
- mask = length_to_mask(lengths * L, max_len=L, device=x.device)
311
- mask = mask.unsqueeze(1)
312
-
313
- # Expand the temporal context of the pooling layer by allowing the
314
- # self-attention to look at global properties of the utterance.
315
- if self.global_context:
316
- # torch.std is unstable for backward computation
317
- # https://github.com/pytorch/pytorch/issues/4320
318
- total = mask.sum(dim=2, keepdim=True).float()
319
- mean, std = _compute_statistics(x, mask / total)
320
- mean = mean.unsqueeze(2).repeat(1, 1, L)
321
- std = std.unsqueeze(2).repeat(1, 1, L)
322
- attn = torch.cat([x, mean, std], dim=1)
323
- else:
324
- attn = x
325
-
326
- # Apply layers
327
- attn = self.conv(self.tanh(self.tdnn(attn)))
328
-
329
- # Filter out zero-paddings
330
- attn = attn.masked_fill(mask == 0, float("-inf"))
331
-
332
- attn = F.softmax(attn, dim=2)
333
- mean, std = _compute_statistics(x, attn)
334
- # Append mean and std of the batch
335
- pooled_stats = torch.cat((mean, std), dim=1)
336
- pooled_stats = pooled_stats.unsqueeze(2)
337
-
338
- return pooled_stats
339
-
340
-
341
- class SERes2NetBlock(nn.Module):
342
- """An implementation of building block in ECAPA-TDNN, i.e.,
343
- TDNN-Res2Net-TDNN-SEBlock.
344
-
345
- Arguments
346
- ---------
347
- in_channels: int
348
- Expected size of input channels.
349
- out_channels: int
350
- The number of output channels.
351
- res2net_scale: int
352
- The scale of the Res2Net block.
353
- se_channels : int
354
- The number of output channels after squeeze.
355
- kernel_size: int
356
- The kernel size of the TDNN blocks.
357
- dilation: int
358
- The dilation of the Res2Net block.
359
- activation : torch class
360
- A class for constructing the activation layers.
361
- groups: int
362
- Number of blocked connections from input channels to output channels.
363
-
364
- Example
365
- -------
366
- >>> x = torch.rand(8, 120, 64).transpose(1, 2)
367
- >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
368
- >>> out = conv(x).transpose(1, 2)
369
- >>> out.shape
370
- torch.Size([8, 120, 64])
371
- """
372
-
373
- def __init__(
374
- self,
375
- in_channels,
376
- out_channels,
377
- res2net_scale=8,
378
- se_channels=128,
379
- kernel_size=1,
380
- dilation=1,
381
- activation=torch.nn.ReLU,
382
- groups=1,
383
- ):
384
- super().__init__()
385
- self.out_channels = out_channels
386
- self.tdnn1 = TDNNBlock(
387
- in_channels,
388
- out_channels,
389
- kernel_size=1,
390
- dilation=1,
391
- activation=activation,
392
- groups=groups,
393
- )
394
- self.res2net_block = Res2NetBlock(
395
- out_channels, out_channels, res2net_scale, kernel_size, dilation
396
- )
397
- self.tdnn2 = TDNNBlock(
398
- out_channels,
399
- out_channels,
400
- kernel_size=1,
401
- dilation=1,
402
- activation=activation,
403
- groups=groups,
404
- )
405
- self.se_block = SEBlock(out_channels, se_channels, out_channels)
406
-
407
- self.shortcut = None
408
- if in_channels != out_channels:
409
- self.shortcut = Conv1d(
410
- in_channels=in_channels,
411
- out_channels=out_channels,
412
- kernel_size=1,
413
- )
414
-
415
- def forward(self, x, lengths=None):
416
- """Processes the input tensor x and returns an output tensor."""
417
- residual = x
418
- if self.shortcut:
419
- residual = self.shortcut(x)
420
-
421
- x = self.tdnn1(x)
422
- x = self.res2net_block(x)
423
- x = self.tdnn2(x)
424
- x = self.se_block(x, lengths)
425
-
426
- return x + residual
427
-
428
-
429
- class ECAPA_TDNN(torch.nn.Module):
430
- """An implementation of the speaker embedding model in a paper.
431
- "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
432
- TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
433
-
434
- Arguments
435
- ---------
436
- input_size : int
437
- Expected size of the input dimension.
438
- device : str
439
- Device used, e.g., "cpu" or "cuda".
440
- lin_neurons : int
441
- Number of neurons in linear layers.
442
- activation : torch class
443
- A class for constructing the activation layers.
444
- channels : list of ints
445
- Output channels for TDNN/SERes2Net layer.
446
- kernel_sizes : list of ints
447
- List of kernel sizes for each layer.
448
- dilations : list of ints
449
- List of dilations for kernels in each layer.
450
- attention_channels: int
451
- The number of attention channels.
452
- res2net_scale : int
453
- The scale of the Res2Net block.
454
- se_channels : int
455
- The number of output channels after squeeze.
456
- global_context: bool
457
- Whether to use global context.
458
- groups : list of ints
459
- List of groups for kernels in each layer.
460
-
461
- Example
462
- -------
463
- >>> input_feats = torch.rand([5, 120, 80])
464
- >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
465
- >>> outputs = compute_embedding(input_feats)
466
- >>> outputs.shape
467
- torch.Size([5, 1, 192])
468
- """
469
-
470
- def __init__(
471
- self,
472
- input_size,
473
- device="cpu",
474
- lin_neurons=192,
475
- activation=torch.nn.ReLU,
476
- channels=[512, 512, 512, 512, 1536],
477
- kernel_sizes=[5, 3, 3, 3, 1],
478
- dilations=[1, 2, 3, 4, 1],
479
- attention_channels=128,
480
- res2net_scale=8,
481
- se_channels=128,
482
- global_context=True,
483
- groups=[1, 1, 1, 1, 1],
484
- ):
485
- super().__init__()
486
- assert len(channels) == len(kernel_sizes)
487
- assert len(channels) == len(dilations)
488
- self.channels = channels
489
- self.blocks = nn.ModuleList()
490
-
491
- # The initial TDNN layer
492
- self.blocks.append(
493
- TDNNBlock(
494
- input_size,
495
- channels[0],
496
- kernel_sizes[0],
497
- dilations[0],
498
- activation,
499
- groups[0],
500
- )
501
- )
502
-
503
- # SE-Res2Net layers
504
- for i in range(1, len(channels) - 1):
505
- self.blocks.append(
506
- SERes2NetBlock(
507
- channels[i - 1],
508
- channels[i],
509
- res2net_scale=res2net_scale,
510
- se_channels=se_channels,
511
- kernel_size=kernel_sizes[i],
512
- dilation=dilations[i],
513
- activation=activation,
514
- groups=groups[i],
515
- )
516
- )
517
-
518
- # Multi-layer feature aggregation
519
- self.mfa = TDNNBlock(
520
- channels[-2] * (len(channels) - 2),
521
- channels[-1],
522
- kernel_sizes[-1],
523
- dilations[-1],
524
- activation,
525
- groups=groups[-1],
526
- )
527
-
528
- # Attentive Statistical Pooling
529
- self.asp = AttentiveStatisticsPooling(
530
- channels[-1],
531
- attention_channels=attention_channels,
532
- global_context=global_context,
533
- )
534
- self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
535
-
536
- # Final linear transformation
537
- self.fc = Conv1d(
538
- in_channels=channels[-1] * 2,
539
- out_channels=lin_neurons,
540
- kernel_size=1,
541
- )
542
-
543
- def forward(self, x, lengths=None):
544
- """Returns the embedding vector.
545
-
546
- Arguments
547
- ---------
548
- x : torch.Tensor
549
- Tensor of shape (batch, time, channel).
550
- lengths : torch.Tensor
551
- Corresponding relative lengths of inputs.
552
-
553
- Returns
554
- -------
555
- x : torch.Tensor
556
- Embedding vector.
557
- """
558
- # Minimize transpose for efficiency
559
- x = x.transpose(1, 2)
560
-
561
- xl = []
562
- for layer in self.blocks:
563
- try:
564
- x = layer(x, lengths=lengths)
565
- except TypeError:
566
- x = layer(x)
567
- xl.append(x)
568
-
569
- # Multi-layer feature aggregation
570
- x = torch.cat(xl[1:], dim=1)
571
- x = self.mfa(x)
572
-
573
- # Attentive Statistical Pooling
574
- x = self.asp(x, lengths=lengths)
575
- x = self.asp_bn(x)
576
-
577
- # Final linear transformation
578
- x = self.fc(x)
579
-
580
- x = x.transpose(1, 2)
581
- return x
582
-
583
-
584
- class Classifier(torch.nn.Module):
585
- """This class implements the cosine similarity on the top of features.
586
-
587
- Arguments
588
- ---------
589
- input_size : int
590
- Expected size of input dimension.
591
- device : str
592
- Device used, e.g., "cpu" or "cuda".
593
- lin_blocks : int
594
- Number of linear layers.
595
- lin_neurons : int
596
- Number of neurons in linear layers.
597
- out_neurons : int
598
- Number of classes.
599
-
600
- Example
601
- -------
602
- >>> classify = Classifier(input_size=2, lin_neurons=2, out_neurons=2)
603
- >>> outputs = torch.tensor([ [1., -1.], [-9., 1.], [0.9, 0.1], [0.1, 0.9] ])
604
- >>> outputs = outputs.unsqueeze(1)
605
- >>> cos = classify(outputs)
606
- >>> (cos < -1.0).long().sum()
607
- tensor(0)
608
- >>> (cos > 1.0).long().sum()
609
- tensor(0)
610
- """
611
-
612
- def __init__(
613
- self,
614
- input_size,
615
- device="cpu",
616
- lin_blocks=0,
617
- lin_neurons=192,
618
- out_neurons=1211,
619
- ):
620
- super().__init__()
621
- self.blocks = nn.ModuleList()
622
-
623
- for block_index in range(lin_blocks):
624
- self.blocks.extend(
625
- [
626
- _BatchNorm1d(input_size=input_size),
627
- Linear(input_size=input_size, n_neurons=lin_neurons),
628
- ]
629
- )
630
- input_size = lin_neurons
631
-
632
- # Final Layer
633
- self.weight = nn.Parameter(
634
- torch.FloatTensor(out_neurons, input_size, device=device)
635
- )
636
- nn.init.xavier_uniform_(self.weight)
637
-
638
- def forward(self, x):
639
- """Returns the output probabilities over speakers.
640
-
641
- Arguments
642
- ---------
643
- x : torch.Tensor
644
- Torch tensor.
645
-
646
- Returns
647
- -------
648
- out : torch.Tensor
649
- Output probabilities over speakers.
650
- """
651
- for layer in self.blocks:
652
- x = layer(x)
653
-
654
- # Need to be normalized
655
- x = F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight))
656
- return x.unsqueeze(1)
indextts/BigVGAN/__init__.py DELETED
File without changes
indextts/BigVGAN/activations.py DELETED
@@ -1,122 +0,0 @@
1
- # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
- # LICENSE is in incl_licenses directory.
3
-
4
- import torch
5
- from torch import nn, pow, sin
6
- from torch.nn import Parameter
7
-
8
-
9
- class Snake(nn.Module):
10
- '''
11
- Implementation of a sine-based periodic activation function
12
- Shape:
13
- - Input: (B, C, T)
14
- - Output: (B, C, T), same shape as the input
15
- Parameters:
16
- - alpha - trainable parameter
17
- References:
18
- - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
- https://arxiv.org/abs/2006.08195
20
- Examples:
21
- >>> a1 = snake(256)
22
- >>> x = torch.randn(256)
23
- >>> x = a1(x)
24
- '''
25
-
26
- def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
27
- '''
28
- Initialization.
29
- INPUT:
30
- - in_features: shape of the input
31
- - alpha: trainable parameter
32
- alpha is initialized to 1 by default, higher values = higher-frequency.
33
- alpha will be trained along with the rest of your model.
34
- '''
35
- super(Snake, self).__init__()
36
- self.in_features = in_features
37
-
38
- # initialize alpha
39
- self.alpha_logscale = alpha_logscale
40
- if self.alpha_logscale: # log scale alphas initialized to zeros
41
- self.alpha = Parameter(torch.zeros(in_features) * alpha)
42
- else: # linear scale alphas initialized to ones
43
- self.alpha = Parameter(torch.ones(in_features) * alpha)
44
-
45
- self.alpha.requires_grad = alpha_trainable
46
-
47
- self.no_div_by_zero = 0.000000001
48
-
49
- def forward(self, x):
50
- '''
51
- Forward pass of the function.
52
- Applies the function to the input elementwise.
53
- Snake ∶= x + 1/a * sin^2 (xa)
54
- '''
55
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
56
- if self.alpha_logscale:
57
- alpha = torch.exp(alpha)
58
- x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
59
-
60
- return x
61
-
62
-
63
- class SnakeBeta(nn.Module):
64
- '''
65
- A modified Snake function which uses separate parameters for the magnitude of the periodic components
66
- Shape:
67
- - Input: (B, C, T)
68
- - Output: (B, C, T), same shape as the input
69
- Parameters:
70
- - alpha - trainable parameter that controls frequency
71
- - beta - trainable parameter that controls magnitude
72
- References:
73
- - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
74
- https://arxiv.org/abs/2006.08195
75
- Examples:
76
- >>> a1 = snakebeta(256)
77
- >>> x = torch.randn(256)
78
- >>> x = a1(x)
79
- '''
80
-
81
- def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
82
- '''
83
- Initialization.
84
- INPUT:
85
- - in_features: shape of the input
86
- - alpha - trainable parameter that controls frequency
87
- - beta - trainable parameter that controls magnitude
88
- alpha is initialized to 1 by default, higher values = higher-frequency.
89
- beta is initialized to 1 by default, higher values = higher-magnitude.
90
- alpha will be trained along with the rest of your model.
91
- '''
92
- super(SnakeBeta, self).__init__()
93
- self.in_features = in_features
94
-
95
- # initialize alpha
96
- self.alpha_logscale = alpha_logscale
97
- if self.alpha_logscale: # log scale alphas initialized to zeros
98
- self.alpha = Parameter(torch.zeros(in_features) * alpha)
99
- self.beta = Parameter(torch.zeros(in_features) * alpha)
100
- else: # linear scale alphas initialized to ones
101
- self.alpha = Parameter(torch.ones(in_features) * alpha)
102
- self.beta = Parameter(torch.ones(in_features) * alpha)
103
-
104
- self.alpha.requires_grad = alpha_trainable
105
- self.beta.requires_grad = alpha_trainable
106
-
107
- self.no_div_by_zero = 0.000000001
108
-
109
- def forward(self, x):
110
- '''
111
- Forward pass of the function.
112
- Applies the function to the input elementwise.
113
- SnakeBeta ∶= x + 1/b * sin^2 (xa)
114
- '''
115
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
116
- beta = self.beta.unsqueeze(0).unsqueeze(-1)
117
- if self.alpha_logscale:
118
- alpha = torch.exp(alpha)
119
- beta = torch.exp(beta)
120
- x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
121
-
122
- return x
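A minimal sketch of the Snake activation removed above, applied to a (batch, channels, time) tensor as its docstring describes; illustrative only, using the pre-removal import path:

import torch
from indextts.BigVGAN.activations import Snake  # path removed by this commit

act = Snake(in_features=256, alpha=1.0)
x = torch.randn(8, 256, 100)   # (B, C, T)
y = act(x)                     # y = x + (1/alpha) * sin^2(alpha * x), one alpha per channel
assert y.shape == x.shape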
indextts/BigVGAN/alias_free_activation/__init__.py DELETED
File without changes
indextts/BigVGAN/alias_free_activation/cuda/.gitignore DELETED
@@ -1 +0,0 @@
1
- /build
indextts/BigVGAN/alias_free_activation/cuda/__init__.py DELETED
File without changes
indextts/BigVGAN/alias_free_activation/cuda/activation1d.py DELETED
@@ -1,76 +0,0 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
- # Licensed under the MIT license.
3
-
4
- import torch
5
- import torch.nn as nn
6
- # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
7
- from indextts.BigVGAN.alias_free_activation.cuda import load
8
- from indextts.BigVGAN.alias_free_activation.torch.resample import DownSample1d, UpSample1d
9
-
10
- anti_alias_activation_cuda = load.load()
11
-
12
-
13
- class FusedAntiAliasActivation(torch.autograd.Function):
14
- """
15
- Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
16
- The hyperparameters are hard-coded in the kernel to maximize speed.
17
- NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
18
- """
19
-
20
- @staticmethod
21
- def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
22
- activation_results = anti_alias_activation_cuda.forward(
23
- inputs, up_ftr, down_ftr, alpha, beta
24
- )
25
-
26
- return activation_results
27
-
28
- @staticmethod
29
- def backward(ctx, output_grads):
30
- raise NotImplementedError
31
- return output_grads, None, None
32
-
33
-
34
- class Activation1d(nn.Module):
35
- def __init__(
36
- self,
37
- activation,
38
- up_ratio: int = 2,
39
- down_ratio: int = 2,
40
- up_kernel_size: int = 12,
41
- down_kernel_size: int = 12,
42
- fused: bool = True,
43
- ):
44
- super().__init__()
45
- self.up_ratio = up_ratio
46
- self.down_ratio = down_ratio
47
- self.act = activation
48
- self.upsample = UpSample1d(up_ratio, up_kernel_size)
49
- self.downsample = DownSample1d(down_ratio, down_kernel_size)
50
-
51
- self.fused = fused # Whether to use fused CUDA kernel or not
52
-
53
- def forward(self, x):
54
- if not self.fused:
55
- x = self.upsample(x)
56
- x = self.act(x)
57
- x = self.downsample(x)
58
- return x
59
- else:
60
- if self.act.__class__.__name__ == "Snake":
61
- beta = self.act.alpha.data # Snake uses same params for alpha and beta
62
- else:
63
- beta = (
64
- self.act.beta.data
65
- ) # Snakebeta uses different params for alpha and beta
66
- alpha = self.act.alpha.data
67
- if (
68
- not self.act.alpha_logscale
69
- ): # Exp baked into cuda kernel, cancel it out with a log
70
- alpha = torch.log(alpha)
71
- beta = torch.log(beta)
72
-
73
- x = FusedAntiAliasActivation.apply(
74
- x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
75
- )
76
- return x
indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp DELETED
@@ -1,23 +0,0 @@
1
- /* coding=utf-8
2
- * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include <torch/extension.h>
18
-
19
- extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
20
-
21
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
- m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23
- }
indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu DELETED
@@ -1,256 +0,0 @@
1
- /* coding=utf-8
2
- * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include <ATen/ATen.h>
18
- #include <cuda.h>
19
- #include <cuda_runtime.h>
20
- #include <cuda_fp16.h>
21
- #include <cuda_profiler_api.h>
22
- #include <ATen/cuda/CUDAContext.h>
23
- #include <torch/extension.h>
24
- #include "type_shim.h"
25
- #include <assert.h>
26
- #include <cfloat>
27
- #include <limits>
28
- #include <stdint.h>
29
- #include <c10/macros/Macros.h>
30
-
31
- namespace
32
- {
33
- // Hard-coded hyperparameters
34
- // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35
- constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36
- constexpr int BUFFER_SIZE = 32;
37
- constexpr int FILTER_SIZE = 12;
38
- constexpr int HALF_FILTER_SIZE = 6;
39
- constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40
- constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41
- constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42
-
43
- template <typename input_t, typename output_t, typename acc_t>
44
- __global__ void anti_alias_activation_forward(
45
- output_t *dst,
46
- const input_t *src,
47
- const acc_t *up_ftr,
48
- const acc_t *down_ftr,
49
- const acc_t *alpha,
50
- const acc_t *beta,
51
- int batch_size,
52
- int channels,
53
- int seq_len)
54
- {
55
- // Up and downsample filters
56
- input_t up_filter[FILTER_SIZE];
57
- input_t down_filter[FILTER_SIZE];
58
-
59
- // Load data from global memory including extra indices reserved for replication paddings
60
- input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61
- input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62
-
63
- // Output stores downsampled output before writing to dst
64
- output_t output[BUFFER_SIZE];
65
-
66
- // blockDim/threadIdx = (128, 1, 1)
67
- // gridDim/blockIdx = (seq_blocks, channels, batches)
68
- int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69
- int local_offset = threadIdx.x * BUFFER_SIZE;
70
- int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71
-
72
- // intermediate have double the seq_len
73
- int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74
- int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75
-
76
- // Get values needed for replication padding before moving pointer
77
- const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78
- input_t seq_left_most_value = right_most_pntr[0];
79
- input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80
-
81
- // Move src and dst pointers
82
- src += block_offset + local_offset;
83
- dst += block_offset + local_offset;
84
-
85
- // Alpha and beta values for snake activatons. Applies exp by default
86
- alpha = alpha + blockIdx.y;
87
- beta = beta + blockIdx.y;
88
-
89
- acc_t alpha_val = expf(alpha[0]);
90
- acc_t beta_val = expf(beta[0]);
91
-
92
- #pragma unroll
93
- for (int it = 0; it < FILTER_SIZE; it += 1)
94
- {
95
- up_filter[it] = up_ftr[it];
96
- down_filter[it] = down_ftr[it];
97
- }
98
-
99
- // Apply replication padding for upsampling, matching torch impl
100
- #pragma unroll
101
- for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
102
- {
103
- int element_index = seq_offset + it; // index for element
104
- if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
105
- {
106
- elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
107
- }
108
- if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
109
- {
110
- elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
111
- }
112
- if ((element_index >= 0) && (element_index < seq_len))
113
- {
114
- elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
115
- }
116
- }
117
-
118
- // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
119
- #pragma unroll
120
- for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
121
- {
122
- acc_t acc = 0.0;
123
- int element_index = intermediate_seq_offset + it; // index for intermediate
124
- #pragma unroll
125
- for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
126
- {
127
- if ((element_index + f_idx) >= 0)
128
- {
129
- acc += up_filter[f_idx] * elements[it + f_idx];
130
- }
131
- }
132
- intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
133
- }
134
-
135
- // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
136
- double no_div_by_zero = 0.000000001;
137
- #pragma unroll
138
- for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
139
- {
140
- acc_t a = sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
141
- intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * a * a;
142
- }
143
-
144
- // Apply replication padding before downsampling conv from intermediates
145
- #pragma unroll
146
- for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
147
- {
148
- intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
149
- }
150
- #pragma unroll
151
- for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
152
- {
153
- intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
154
- }
155
-
156
- // Apply downsample strided convolution (assuming stride=2) from intermediates
157
- #pragma unroll
158
- for (int it = 0; it < BUFFER_SIZE; it += 1)
159
- {
160
- acc_t acc = 0.0;
161
- #pragma unroll
162
- for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
163
- {
164
- // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
165
- acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
166
- }
167
- output[it] = acc;
168
- }
169
-
170
- // Write output to dst
171
- #pragma unroll
172
- for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
173
- {
174
- int element_index = seq_offset + it;
175
- if (element_index < seq_len)
176
- {
177
- dst[it] = output[it];
178
- }
179
- }
180
-
181
- }
182
-
183
- template <typename input_t, typename output_t, typename acc_t>
184
- void dispatch_anti_alias_activation_forward(
185
- output_t *dst,
186
- const input_t *src,
187
- const acc_t *up_ftr,
188
- const acc_t *down_ftr,
189
- const acc_t *alpha,
190
- const acc_t *beta,
191
- int batch_size,
192
- int channels,
193
- int seq_len)
194
- {
195
- if (seq_len == 0)
196
- {
197
- return;
198
- }
199
- else
200
- {
201
- // Use 128 threads per block to maximimize gpu utilization
202
- constexpr int threads_per_block = 128;
203
- constexpr int seq_len_per_block = 4096;
204
- int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
205
- dim3 blocks(blocks_per_seq_len, channels, batch_size);
206
- dim3 threads(threads_per_block, 1, 1);
207
-
208
- anti_alias_activation_forward<input_t, output_t, acc_t>
209
- <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
210
- }
211
- }
212
- }
213
-
214
- extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
215
- {
216
- // Input is a 3d tensor with dimensions [batches, channels, seq_len]
217
- const int batches = input.size(0);
218
- const int channels = input.size(1);
219
- const int seq_len = input.size(2);
220
-
221
- // Output
222
- auto act_options = input.options().requires_grad(false);
223
-
224
- torch::Tensor anti_alias_activation_results =
225
- torch::empty({batches, channels, seq_len}, act_options);
226
-
227
- using float32 = float;
228
- // The dtype of input is float16, bfloat16, or float32
229
- // The dtype of up_filter, down_filter, alpha, and beta is float32
230
- // printf("input scalar type: %d\n", input.scalar_type());
231
- // printf("up_filter scalar type: %d\n", up_filter.scalar_type());
232
- // printf("down_filter scalar type: %d\n", down_filter.scalar_type());
233
- // printf("alpha scalar type: %d\n", alpha.scalar_type());
234
- // printf("beta scalar type: %d\n", beta.scalar_type());
235
- void *input_ptr = static_cast<void *>(input.data_ptr());
236
- float32 *up_filter_ptr = static_cast<float32 *>(up_filter.data_ptr());
237
- float32 *down_filter_ptr = static_cast<float32 *>(down_filter.data_ptr());
238
- float32 *alpha_ptr = static_cast<float32 *>(alpha.data_ptr());
239
- float32 *beta_ptr = static_cast<float32 *>(beta.data_ptr());
240
- void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
241
-
242
- DISPATCH_FLOAT_HALF_AND_BFLOAT(
243
- input.scalar_type(),
244
- "dispatch anti alias activation_forward",
245
- dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float32>(
246
- reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
247
- reinterpret_cast<const scalar_t *>(input_ptr),
248
- reinterpret_cast<const float32 *>(up_filter_ptr),
249
- reinterpret_cast<const float32 *>(down_filter_ptr),
250
- reinterpret_cast<const float32 *>(alpha_ptr),
251
- reinterpret_cast<const float32 *>(beta_ptr),
252
- batches,
253
- channels,
254
- seq_len););
255
- return anti_alias_activation_results;
256
- }
indextts/BigVGAN/alias_free_activation/cuda/compat.h DELETED
@@ -1,29 +0,0 @@
1
- /* coding=utf-8
2
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- /*This code is copied fron NVIDIA apex:
18
- * https://github.com/NVIDIA/apex
19
- * with minor changes. */
20
-
21
- #ifndef TORCH_CHECK
22
- #define TORCH_CHECK AT_CHECK
23
- #endif
24
-
25
- #ifdef VERSION_GE_1_3
26
- #define DATA_PTR data_ptr
27
- #else
28
- #define DATA_PTR data
29
- #endif
indextts/BigVGAN/alias_free_activation/cuda/load.py DELETED
@@ -1,121 +0,0 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
- # Licensed under the MIT license.
3
-
4
- import os
5
- import pathlib
6
- import subprocess
7
-
8
- from torch.utils import cpp_extension
9
-
10
- """
11
- Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
12
- Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
13
- """
14
- os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15
-
16
-
17
- import re
18
- import shutil
19
- import tempfile
20
-
21
- # 补丁修复:sources 路径含中文字符时,生成 build.ninja 乱码导致编译失败
22
- # 使用临时目录来规避 ninja 编译失败(比如中文路径)
23
- def chinese_path_compile_support(sources, buildpath):
24
- pattern = re.compile(r'[\u4e00-\u9fff]')
25
- if not bool(pattern.search(str(sources[0].resolve()))):
26
- return buildpath # 检测非中文路径跳过
27
- # Create build directory
28
- resolves = [ item.name for item in sources]
29
- ninja_compile_dir = os.path.join(tempfile.gettempdir(), "BigVGAN", "cuda")
30
- os.makedirs(ninja_compile_dir, exist_ok=True)
31
- new_buildpath = os.path.join(ninja_compile_dir, "build")
32
- os.makedirs(new_buildpath, exist_ok=True)
33
- print(f"ninja_buildpath: {new_buildpath}")
34
- # Copy files to directory
35
- sources.clear()
36
- current_dir = os.path.dirname(__file__)
37
- ALLOWED_EXTENSIONS = {'.py', '.cu', '.cpp', '.h'}
38
- for filename in os.listdir(current_dir):
39
- item = pathlib.Path(current_dir).joinpath(filename)
40
- tar_path = pathlib.Path(ninja_compile_dir).joinpath(item.name)
41
- if not item.suffix.lower() in ALLOWED_EXTENSIONS:continue
42
- pathlib.Path(shutil.copy2(item, tar_path))
43
- if tar_path.name in resolves:sources.append(tar_path)
44
- return new_buildpath
45
-
46
-
47
-
48
- def load():
49
- # Check if cuda 11 is installed for compute capability 8.0
50
- cc_flag = []
51
- _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
52
- if int(bare_metal_major) >= 11:
53
- cc_flag.append("-gencode")
54
- cc_flag.append("arch=compute_80,code=sm_80")
55
-
56
- # Build path
57
- srcpath = pathlib.Path(__file__).parent.absolute()
58
- buildpath = srcpath / "build"
59
- _create_build_dir(buildpath)
60
-
61
- # Helper function to build the kernels.
62
- def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
63
- return cpp_extension.load(
64
- name=name,
65
- sources=sources,
66
- build_directory=buildpath,
67
- extra_cflags=[
68
- "-O3",
69
- ],
70
- extra_cuda_cflags=[
71
- "-O3",
72
- "-gencode",
73
- "arch=compute_70,code=sm_70",
74
- "--use_fast_math",
75
- ]
76
- + extra_cuda_flags
77
- + cc_flag,
78
- verbose=True,
79
- )
80
-
81
- extra_cuda_flags = [
82
- "-U__CUDA_NO_HALF_OPERATORS__",
83
- "-U__CUDA_NO_HALF_CONVERSIONS__",
84
- "--expt-relaxed-constexpr",
85
- "--expt-extended-lambda",
86
- ]
87
-
88
- sources = [
89
- srcpath / "anti_alias_activation.cpp",
90
- srcpath / "anti_alias_activation_cuda.cu",
91
- ]
92
-
93
- # 兼容方案:ninja 特殊字符路径编译支持处理(比如中文路径)
94
- buildpath = chinese_path_compile_support(sources, buildpath)
95
-
96
- anti_alias_activation_cuda = _cpp_extention_load_helper(
97
- "anti_alias_activation_cuda", sources, extra_cuda_flags
98
- )
99
-
100
- return anti_alias_activation_cuda
101
-
102
-
103
- def _get_cuda_bare_metal_version(cuda_dir):
104
- raw_output = subprocess.check_output(
105
- [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
106
- )
107
- output = raw_output.split()
108
- release_idx = output.index("release") + 1
109
- release = output[release_idx].split(".")
110
- bare_metal_major = release[0]
111
- bare_metal_minor = release[1][0]
112
-
113
- return raw_output, bare_metal_major, bare_metal_minor
114
-
115
-
116
- def _create_build_dir(buildpath):
117
- try:
118
- os.mkdir(buildpath)
119
- except OSError:
120
- if not os.path.isdir(buildpath):
121
- print(f"Creation of the build directory {buildpath} failed")
indextts/BigVGAN/alias_free_activation/cuda/type_shim.h DELETED
@@ -1,92 +0,0 @@
1
- /* coding=utf-8
2
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include <ATen/ATen.h>
18
- #include "compat.h"
19
-
20
- #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21
- switch (TYPE) \
22
- { \
23
- case at::ScalarType::Float: \
24
- { \
25
- using scalar_t = float; \
26
- __VA_ARGS__; \
27
- break; \
28
- } \
29
- case at::ScalarType::Half: \
30
- { \
31
- using scalar_t = at::Half; \
32
- __VA_ARGS__; \
33
- break; \
34
- } \
35
- case at::ScalarType::BFloat16: \
36
- { \
37
- using scalar_t = at::BFloat16; \
38
- __VA_ARGS__; \
39
- break; \
40
- } \
41
- default: \
42
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43
- }
44
-
45
- #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46
- switch (TYPEIN) \
47
- { \
48
- case at::ScalarType::Float: \
49
- { \
50
- using scalar_t_in = float; \
51
- switch (TYPEOUT) \
52
- { \
53
- case at::ScalarType::Float: \
54
- { \
55
- using scalar_t_out = float; \
56
- __VA_ARGS__; \
57
- break; \
58
- } \
59
- case at::ScalarType::Half: \
60
- { \
61
- using scalar_t_out = at::Half; \
62
- __VA_ARGS__; \
63
- break; \
64
- } \
65
- case at::ScalarType::BFloat16: \
66
- { \
67
- using scalar_t_out = at::BFloat16; \
68
- __VA_ARGS__; \
69
- break; \
70
- } \
71
- default: \
72
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73
- } \
74
- break; \
75
- } \
76
- case at::ScalarType::Half: \
77
- { \
78
- using scalar_t_in = at::Half; \
79
- using scalar_t_out = at::Half; \
80
- __VA_ARGS__; \
81
- break; \
82
- } \
83
- case at::ScalarType::BFloat16: \
84
- { \
85
- using scalar_t_in = at::BFloat16; \
86
- using scalar_t_out = at::BFloat16; \
87
- __VA_ARGS__; \
88
- break; \
89
- } \
90
- default: \
91
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92
- }
indextts/BigVGAN/alias_free_activation/torch/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- from .act import *
5
- from .filter import *
6
- from .resample import *
indextts/BigVGAN/alias_free_activation/torch/act.py DELETED
@@ -1,31 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- import torch.nn as nn
5
-
6
- from .resample import DownSample1d, UpSample1d
7
-
8
-
9
- class Activation1d(nn.Module):
10
- def __init__(
11
- self,
12
- activation,
13
- up_ratio: int = 2,
14
- down_ratio: int = 2,
15
- up_kernel_size: int = 12,
16
- down_kernel_size: int = 12,
17
- ):
18
- super().__init__()
19
- self.up_ratio = up_ratio
20
- self.down_ratio = down_ratio
21
- self.act = activation
22
- self.upsample = UpSample1d(up_ratio, up_kernel_size)
23
- self.downsample = DownSample1d(down_ratio, down_kernel_size)
24
-
25
- # x: [B,C,T]
26
- def forward(self, x):
27
- x = self.upsample(x)
28
- x = self.act(x)
29
- x = self.downsample(x)
30
-
31
- return x
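A minimal sketch of the alias-free Activation1d wrapper removed above (upsample, apply the activation, downsample); illustrative only, with the default 2x ratios and the pre-removal import paths:

import torch
from indextts.BigVGAN.activations import Snake
from indextts.BigVGAN.alias_free_activation.torch.act import Activation1d  # removed by this commit

act = Activation1d(activation=Snake(in_features=64))
x = torch.randn(2, 64, 400)    # [B, C, T]
y = act(x)                     # 2x upsample -> Snake -> 2x downsample
assert y.shape == x.shape      # the resampling pair preserves the sequence length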
indextts/BigVGAN/alias_free_activation/torch/filter.py DELETED
@@ -1,102 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- import math
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- if "sinc" in dir(torch):
11
- sinc = torch.sinc
12
- else:
13
- # This code is adopted from adefossez's julius.core.sinc under the MIT License
14
- # https://adefossez.github.io/julius/julius/core.html
15
- # LICENSE is in incl_licenses directory.
16
- def sinc(x: torch.Tensor):
17
- """
18
- Implementation of sinc, i.e. sin(pi * x) / (pi * x)
19
- __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
20
- """
21
- return torch.where(
22
- x == 0,
23
- torch.tensor(1.0, device=x.device, dtype=x.dtype),
24
- torch.sin(math.pi * x) / math.pi / x,
25
- )
26
-
27
-
28
- # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
29
- # https://adefossez.github.io/julius/julius/lowpass.html
30
- # LICENSE is in incl_licenses directory.
31
- def kaiser_sinc_filter1d(
32
- cutoff, half_width, kernel_size
33
- ): # return filter [1,1,kernel_size]
34
- even = kernel_size % 2 == 0
35
- half_size = kernel_size // 2
36
-
37
- # For kaiser window
38
- delta_f = 4 * half_width
39
- A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
40
- if A > 50.0:
41
- beta = 0.1102 * (A - 8.7)
42
- elif A >= 21.0:
43
- beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
44
- else:
45
- beta = 0.0
46
- window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
47
-
48
- # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
49
- if even:
50
- time = torch.arange(-half_size, half_size) + 0.5
51
- else:
52
- time = torch.arange(kernel_size) - half_size
53
- if cutoff == 0:
54
- filter_ = torch.zeros_like(time)
55
- else:
56
- filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
57
- """
58
- Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
59
- """
60
- filter_ /= filter_.sum()
61
- filter = filter_.view(1, 1, kernel_size)
62
-
63
- return filter
64
-
65
-
66
- class LowPassFilter1d(nn.Module):
67
- def __init__(
68
- self,
69
- cutoff=0.5,
70
- half_width=0.6,
71
- stride: int = 1,
72
- padding: bool = True,
73
- padding_mode: str = "replicate",
74
- kernel_size: int = 12,
75
- ):
76
- """
77
- kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
78
- """
79
- super().__init__()
80
- if cutoff < -0.0:
81
- raise ValueError("Minimum cutoff must be larger than zero.")
82
- if cutoff > 0.5:
83
- raise ValueError("A cutoff above 0.5 does not make sense.")
84
- self.kernel_size = kernel_size
85
- self.even = kernel_size % 2 == 0
86
- self.pad_left = kernel_size // 2 - int(self.even)
87
- self.pad_right = kernel_size // 2
88
- self.stride = stride
89
- self.padding = padding
90
- self.padding_mode = padding_mode
91
- filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
92
- self.register_buffer("filter", filter)
93
-
94
- # Input [B, C, T]
95
- def forward(self, x):
96
- _, C, _ = x.shape
97
-
98
- if self.padding:
99
- x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
100
- out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
101
-
102
- return out
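A minimal sketch of the Kaiser-windowed sinc low-pass filter removed above, configured for a 2x resampler (cutoff at a quarter of the sampling rate); illustrative only, using the pre-removal import path:

import torch
from indextts.BigVGAN.alias_free_activation.torch.filter import (  # removed by this commit
    LowPassFilter1d,
    kaiser_sinc_filter1d,
)

kernel = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(kernel.shape)            # torch.Size([1, 1, 12]); coefficients are normalized to sum to 1

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(1, 4, 128)     # [B, C, T]
y = lpf(x)                     # low-pass filtered and decimated by 2
print(y.shape)                 # torch.Size([1, 4, 64])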
indextts/BigVGAN/alias_free_activation/torch/resample.py DELETED
@@ -1,58 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- import torch.nn as nn
5
- from torch.nn import functional as F
6
-
7
- from .filter import LowPassFilter1d, kaiser_sinc_filter1d
8
-
9
-
10
- class UpSample1d(nn.Module):
11
- def __init__(self, ratio=2, kernel_size=None):
12
- super().__init__()
13
- self.ratio = ratio
14
- self.kernel_size = (
15
- int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
16
- )
17
- self.stride = ratio
18
- self.pad = self.kernel_size // ratio - 1
19
- self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
20
- self.pad_right = (
21
- self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
22
- )
23
- filter = kaiser_sinc_filter1d(
24
- cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
25
- )
26
- self.register_buffer("filter", filter)
27
-
28
- # x: [B, C, T]
29
- def forward(self, x):
30
- _, C, _ = x.shape
31
-
32
- x = F.pad(x, (self.pad, self.pad), mode="replicate")
33
- x = self.ratio * F.conv_transpose1d(
34
- x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
35
- )
36
- x = x[..., self.pad_left : -self.pad_right]
37
-
38
- return x
39
-
40
-
41
- class DownSample1d(nn.Module):
42
- def __init__(self, ratio=2, kernel_size=None):
43
- super().__init__()
44
- self.ratio = ratio
45
- self.kernel_size = (
46
- int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
47
- )
48
- self.lowpass = LowPassFilter1d(
49
- cutoff=0.5 / ratio,
50
- half_width=0.6 / ratio,
51
- stride=ratio,
52
- kernel_size=self.kernel_size,
53
- )
54
-
55
- def forward(self, x):
56
- xx = self.lowpass(x)
57
-
58
- return xx
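
UpSample1d and DownSample1d compose into the anti-aliased 2x up/down path that wraps BigVGAN's activations. A short round-trip sketch, again assuming the pre-removal import path; the sine input is only a smoke test:

    import math
    import torch
    from indextts.BigVGAN.alias_free_activation.torch.resample import UpSample1d, DownSample1d

    up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
    x = torch.sin(torch.linspace(0, 20 * math.pi, 256)).view(1, 1, -1)  # [B, C, T]
    y = up(x)    # [1, 1, 512]: kaiser-windowed sinc interpolation to twice the rate
    z = down(y)  # [1, 1, 256]: low-pass filtering and decimation back to the original rate
    print(y.shape, z.shape)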
 
indextts/BigVGAN/alias_free_torch/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- from .act import *
5
- from .filter import *
6
- from .resample import *
 
indextts/BigVGAN/alias_free_torch/act.py DELETED
@@ -1,29 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- import torch.nn as nn
5
-
6
- from .resample import DownSample1d, UpSample1d
7
-
8
-
9
- class Activation1d(nn.Module):
10
- def __init__(self,
11
- activation,
12
- up_ratio: int = 2,
13
- down_ratio: int = 2,
14
- up_kernel_size: int = 12,
15
- down_kernel_size: int = 12):
16
- super().__init__()
17
- self.up_ratio = up_ratio
18
- self.down_ratio = down_ratio
19
- self.act = activation
20
- self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
- self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
-
23
- # x: [B,C,T]
24
- def forward(self, x):
25
- x = self.upsample(x)
26
- x = self.act(x)
27
- x = self.downsample(x)
28
-
29
- return x
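
Activation1d applies a pointwise nonlinearity at twice the sample rate and then comes back down, which is what suppresses the aliasing the nonlinearity would otherwise introduce. A sketch of using it on its own; the BigVGAN configs wrap the Snake / SnakeBeta activations from indextts/BigVGAN/activations.py, so the SiLU below is only a stand-in to keep the example self-contained:

    import torch
    from indextts.BigVGAN.alias_free_torch import Activation1d  # re-exported by the __init__.py above

    act = Activation1d(activation=torch.nn.SiLU())  # stand-in; the real configs use Snake / SnakeBeta
    x = torch.randn(1, 8, 1024)                     # [B, C, T]
    y = act(x)                                      # upsample 2x -> nonlinearity -> downsample 2x
    assert y.shape == x.shape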
 
indextts/BigVGAN/alias_free_torch/filter.py DELETED
@@ -1,96 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- import math
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- if 'sinc' in dir(torch):
11
- sinc = torch.sinc
12
- else:
13
- # This code is adopted from adefossez's julius.core.sinc under the MIT License
14
- # https://adefossez.github.io/julius/julius/core.html
15
- # LICENSE is in incl_licenses directory.
16
- def sinc(x: torch.Tensor):
17
- """
18
- Implementation of sinc, i.e. sin(pi * x) / (pi * x)
19
- __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
20
- """
21
- return torch.where(x == 0,
22
- torch.tensor(1., device=x.device, dtype=x.dtype),
23
- torch.sin(math.pi * x) / math.pi / x)
24
-
25
-
26
- # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
27
- # https://adefossez.github.io/julius/julius/lowpass.html
28
- # LICENSE is in incl_licenses directory.
29
- def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
30
- even = (kernel_size % 2 == 0)
31
- half_size = kernel_size // 2
32
-
33
- #For kaiser window
34
- delta_f = 4 * half_width
35
- A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36
- if A > 50.:
37
- beta = 0.1102 * (A - 8.7)
38
- elif A >= 21.:
39
- beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
40
- else:
41
- beta = 0.
42
- window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43
-
44
- # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45
- if even:
46
- time = (torch.arange(-half_size, half_size) + 0.5)
47
- else:
48
- time = torch.arange(kernel_size) - half_size
49
- if cutoff == 0:
50
- filter_ = torch.zeros_like(time)
51
- else:
52
- filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53
- # Normalize filter to have sum = 1, otherwise we will have a small leakage
54
- # of the constant component in the input signal.
55
- filter_ /= filter_.sum()
56
- filter = filter_.view(1, 1, kernel_size)
57
-
58
- return filter
59
-
60
-
61
- class LowPassFilter1d(nn.Module):
62
- def __init__(self,
63
- cutoff=0.5,
64
- half_width=0.6,
65
- stride: int = 1,
66
- padding: bool = True,
67
- padding_mode: str = 'replicate',
68
- kernel_size: int = 12):
69
- # kernel_size should be even number for stylegan3 setup,
70
- # in this implementation, odd number is also possible.
71
- super().__init__()
72
- if cutoff < -0.:
73
- raise ValueError("Minimum cutoff must be larger than zero.")
74
- if cutoff > 0.5:
75
- raise ValueError("A cutoff above 0.5 does not make sense.")
76
- self.kernel_size = kernel_size
77
- self.even = (kernel_size % 2 == 0)
78
- self.pad_left = kernel_size // 2 - int(self.even)
79
- self.pad_right = kernel_size // 2
80
- self.stride = stride
81
- self.padding = padding
82
- self.padding_mode = padding_mode
83
- filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
84
- self.register_buffer("filter", filter)
85
-
86
- #input [B, C, T]
87
- def forward(self, x):
88
- _, C, _ = x.shape
89
-
90
- if self.padding:
91
- x = F.pad(x, (self.pad_left, self.pad_right),
92
- mode=self.padding_mode)
93
- out = F.conv1d(x, self.filter.expand(C, -1, -1),
94
- stride=self.stride, groups=C)
95
-
96
- return out
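
This copy of the filter module mirrors the one under alias_free_activation/torch; the detail worth keeping in mind is the final normalization in kaiser_sinc_filter1d, which forces unit DC gain so the constant component of the input is not attenuated. A quick check with illustrative values:

    from indextts.BigVGAN.alias_free_torch.filter import kaiser_sinc_filter1d

    filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    print(filt.shape)         # torch.Size([1, 1, 12]), ready for F.conv1d with groups=C
    print(float(filt.sum()))  # ~1.0, because of the filter_ /= filter_.sum() step above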
 
indextts/BigVGAN/alias_free_torch/resample.py DELETED
@@ -1,49 +0,0 @@
1
- # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
- # LICENSE is in incl_licenses directory.
3
-
4
- import torch.nn as nn
5
- from torch.nn import functional as F
6
-
7
- from .filter import LowPassFilter1d, kaiser_sinc_filter1d
8
-
9
-
10
- class UpSample1d(nn.Module):
11
- def __init__(self, ratio=2, kernel_size=None):
12
- super().__init__()
13
- self.ratio = ratio
14
- self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
- self.stride = ratio
16
- self.pad = self.kernel_size // ratio - 1
17
- self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
- self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
- filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
- half_width=0.6 / ratio,
21
- kernel_size=self.kernel_size)
22
- self.register_buffer("filter", filter)
23
-
24
- # x: [B, C, T]
25
- def forward(self, x):
26
- _, C, _ = x.shape
27
-
28
- x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
- x = self.ratio * F.conv_transpose1d(
30
- x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
- x = x[..., self.pad_left:-self.pad_right]
32
-
33
- return x
34
-
35
-
36
- class DownSample1d(nn.Module):
37
- def __init__(self, ratio=2, kernel_size=None):
38
- super().__init__()
39
- self.ratio = ratio
40
- self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
- self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
- half_width=0.6 / ratio,
43
- stride=ratio,
44
- kernel_size=self.kernel_size)
45
-
46
- def forward(self, x):
47
- xx = self.lowpass(x)
48
-
49
- return xx
 
indextts/BigVGAN/bigvgan.py DELETED
@@ -1,534 +0,0 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
- # Licensed under the MIT license.
3
-
4
- # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
- # LICENSE is in incl_licenses directory.
6
-
7
- import json
8
- import os
9
- from pathlib import Path
10
- from typing import Dict, Optional, Union
11
-
12
- import torch
13
- import torch.nn as nn
14
- from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
15
- from torch.nn import Conv1d, ConvTranspose1d
16
- from torch.nn.utils import remove_weight_norm, weight_norm
17
-
18
- import indextts.BigVGAN.activations as activations
19
- from indextts.BigVGAN.alias_free_activation.torch.act import \
20
- Activation1d as TorchActivation1d
21
- from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
22
- from indextts.BigVGAN.env import AttrDict
23
- from indextts.BigVGAN.utils import get_padding, init_weights
24
-
25
-
26
- def load_hparams_from_json(path) -> AttrDict:
27
- with open(path) as f:
28
- data = f.read()
29
- return AttrDict(json.loads(data))
30
-
31
-
32
- class AMPBlock1(torch.nn.Module):
33
- """
34
- AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
35
- AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
36
-
37
- Args:
38
- h (AttrDict): Hyperparameters.
39
- channels (int): Number of convolution channels.
40
- kernel_size (int): Size of the convolution kernel. Default is 3.
41
- dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
42
- activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
43
- """
44
-
45
- def __init__(
46
- self,
47
- h: AttrDict,
48
- channels: int,
49
- kernel_size: int = 3,
50
- dilation: tuple = (1, 3, 5),
51
- activation: str = None,
52
- ):
53
- super().__init__()
54
-
55
- self.h = h
56
-
57
- self.convs1 = nn.ModuleList(
58
- [
59
- weight_norm(
60
- Conv1d(
61
- channels,
62
- channels,
63
- kernel_size,
64
- stride=1,
65
- dilation=d,
66
- padding=get_padding(kernel_size, d),
67
- )
68
- )
69
- for d in dilation
70
- ]
71
- )
72
- self.convs1.apply(init_weights)
73
-
74
- self.convs2 = nn.ModuleList(
75
- [
76
- weight_norm(
77
- Conv1d(
78
- channels,
79
- channels,
80
- kernel_size,
81
- stride=1,
82
- dilation=1,
83
- padding=get_padding(kernel_size, 1),
84
- )
85
- )
86
- for _ in range(len(dilation))
87
- ]
88
- )
89
- self.convs2.apply(init_weights)
90
-
91
- self.num_layers = len(self.convs1) + len(
92
- self.convs2
93
- ) # Total number of conv layers
94
-
95
- # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
96
- if self.h.get("use_cuda_kernel", False):
97
- from alias_free_activation.cuda.activation1d import \
98
- Activation1d as CudaActivation1d
99
-
100
- Activation1d = CudaActivation1d
101
- else:
102
- Activation1d = TorchActivation1d
103
-
104
- # Activation functions
105
- if activation == "snake":
106
- self.activations = nn.ModuleList(
107
- [
108
- Activation1d(
109
- activation=activations.Snake(
110
- channels, alpha_logscale=h.snake_logscale
111
- )
112
- )
113
- for _ in range(self.num_layers)
114
- ]
115
- )
116
- elif activation == "snakebeta":
117
- self.activations = nn.ModuleList(
118
- [
119
- Activation1d(
120
- activation=activations.SnakeBeta(
121
- channels, alpha_logscale=h.snake_logscale
122
- )
123
- )
124
- for _ in range(self.num_layers)
125
- ]
126
- )
127
- else:
128
- raise NotImplementedError(
129
- "activation incorrectly specified. check the config file and look for 'activation'."
130
- )
131
-
132
- def forward(self, x):
133
- acts1, acts2 = self.activations[::2], self.activations[1::2]
134
- for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
135
- xt = a1(x)
136
- xt = c1(xt)
137
- xt = a2(xt)
138
- xt = c2(xt)
139
- x = xt + x
140
-
141
- return x
142
-
143
- def remove_weight_norm(self):
144
- for l in self.convs1:
145
- remove_weight_norm(l)
146
- for l in self.convs2:
147
- remove_weight_norm(l)
148
-
149
-
150
- class AMPBlock2(torch.nn.Module):
151
- """
152
- AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
153
- Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
154
-
155
- Args:
156
- h (AttrDict): Hyperparameters.
157
- channels (int): Number of convolution channels.
158
- kernel_size (int): Size of the convolution kernel. Default is 3.
159
- dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
160
- activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
161
- """
162
-
163
- def __init__(
164
- self,
165
- h: AttrDict,
166
- channels: int,
167
- kernel_size: int = 3,
168
- dilation: tuple = (1, 3, 5),
169
- activation: str = None,
170
- ):
171
- super().__init__()
172
-
173
- self.h = h
174
-
175
- self.convs = nn.ModuleList(
176
- [
177
- weight_norm(
178
- Conv1d(
179
- channels,
180
- channels,
181
- kernel_size,
182
- stride=1,
183
- dilation=d,
184
- padding=get_padding(kernel_size, d),
185
- )
186
- )
187
- for d in dilation
188
- ]
189
- )
190
- self.convs.apply(init_weights)
191
-
192
- self.num_layers = len(self.convs) # Total number of conv layers
193
-
194
- # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
195
- if self.h.get("use_cuda_kernel", False):
196
- from alias_free_activation.cuda.activation1d import \
197
- Activation1d as CudaActivation1d
198
-
199
- Activation1d = CudaActivation1d
200
- else:
201
- Activation1d = TorchActivation1d
202
-
203
- # Activation functions
204
- if activation == "snake":
205
- self.activations = nn.ModuleList(
206
- [
207
- Activation1d(
208
- activation=activations.Snake(
209
- channels, alpha_logscale=h.snake_logscale
210
- )
211
- )
212
- for _ in range(self.num_layers)
213
- ]
214
- )
215
- elif activation == "snakebeta":
216
- self.activations = nn.ModuleList(
217
- [
218
- Activation1d(
219
- activation=activations.SnakeBeta(
220
- channels, alpha_logscale=h.snake_logscale
221
- )
222
- )
223
- for _ in range(self.num_layers)
224
- ]
225
- )
226
- else:
227
- raise NotImplementedError(
228
- "activation incorrectly specified. check the config file and look for 'activation'."
229
- )
230
-
231
- def forward(self, x):
232
- for c, a in zip(self.convs, self.activations):
233
- xt = a(x)
234
- xt = c(xt)
235
- x = xt + x
236
- return x
237
-
238
- def remove_weight_norm(self):
239
- for l in self.convs:
240
- remove_weight_norm(l)
241
-
242
-
243
- '''
244
- PyTorchModelHubMixin,
245
- library_name="bigvgan",
246
- repo_url="https://github.com/NVIDIA/BigVGAN",
247
- docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
248
- pipeline_tag="audio-to-audio",
249
- license="mit",
250
- tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
251
- '''
252
-
253
-
254
- class BigVGAN(
255
- torch.nn.Module,
256
- ):
257
- """
258
- BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
259
- New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
260
-
261
- Args:
262
- h (AttrDict): Hyperparameters.
263
- use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
264
-
265
- Note:
266
- - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
267
- - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
268
- """
269
-
270
- def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
271
- super().__init__()
272
- self.h = h
273
- self.h["use_cuda_kernel"] = use_cuda_kernel
274
-
275
- # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
276
- if self.h.get("use_cuda_kernel", False):
277
- from alias_free_activation.cuda.activation1d import \
278
- Activation1d as CudaActivation1d
279
-
280
- Activation1d = CudaActivation1d
281
- else:
282
- Activation1d = TorchActivation1d
283
-
284
- self.num_kernels = len(h.resblock_kernel_sizes)
285
- self.num_upsamples = len(h.upsample_rates)
286
-
287
- self.feat_upsample = h.feat_upsample
288
- self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
289
-
290
- # Pre-conv
291
- self.conv_pre = weight_norm(
292
- Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3)
293
- )
294
-
295
- # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
296
- if h.resblock == "1":
297
- resblock_class = AMPBlock1
298
- elif h.resblock == "2":
299
- resblock_class = AMPBlock2
300
- else:
301
- raise ValueError(
302
- f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
303
- )
304
-
305
- # Transposed conv-based upsamplers. does not apply anti-aliasing
306
- self.ups = nn.ModuleList()
307
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
308
- self.ups.append(
309
- nn.ModuleList(
310
- [
311
- weight_norm(
312
- ConvTranspose1d(
313
- h.upsample_initial_channel // (2**i),
314
- h.upsample_initial_channel // (2 ** (i + 1)),
315
- k,
316
- u,
317
- padding=(k - u) // 2,
318
- )
319
- )
320
- ]
321
- )
322
- )
323
-
324
- # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
325
- self.resblocks = nn.ModuleList()
326
- for i in range(len(self.ups)):
327
- ch = h.upsample_initial_channel // (2 ** (i + 1))
328
- for j, (k, d) in enumerate(
329
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
330
- ):
331
- self.resblocks.append(
332
- resblock_class(h, ch, k, d, activation=h.activation)
333
- )
334
-
335
- # Post-conv
336
- activation_post = (
337
- activations.Snake(ch, alpha_logscale=h.snake_logscale)
338
- if h.activation == "snake"
339
- else (
340
- activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
341
- if h.activation == "snakebeta"
342
- else None
343
- )
344
- )
345
- if activation_post is None:
346
- raise NotImplementedError(
347
- "activation incorrectly specified. check the config file and look for 'activation'."
348
- )
349
-
350
- self.activation_post = Activation1d(activation=activation_post)
351
-
352
- # Whether to use bias for the final conv_post. Default to True for backward compatibility
353
- self.use_bias_at_final = h.get("use_bias_at_final", True)
354
- self.conv_post = weight_norm(
355
- Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
356
- )
357
-
358
- # Weight initialization
359
- for i in range(len(self.ups)):
360
- self.ups[i].apply(init_weights)
361
- self.conv_post.apply(init_weights)
362
-
363
- # Final tanh activation. Defaults to True for backward compatibility
364
- self.use_tanh_at_final = h.get("use_tanh_at_final", True)
365
-
366
- self.speaker_encoder = ECAPA_TDNN(h.num_mels, lin_neurons=h.speaker_embedding_dim)
367
- self.cond_layer = nn.Conv1d(h.speaker_embedding_dim, h.upsample_initial_channel, 1)
368
- if self.cond_in_each_up_layer:
369
- self.conds = nn.ModuleList()
370
- for i in range(len(self.ups)):
371
- ch = h.upsample_initial_channel // (2 ** (i + 1))
372
- self.conds.append(nn.Conv1d(h.speaker_embedding_dim, ch, 1))
373
-
374
- def forward(self, x, mel_refer, lens=None):
375
- # Speaker reference
376
- speaker_embedding = self.speaker_encoder(mel_refer, lens)
377
- n_batch = x.size(0)
378
- contrastive_loss = None
379
- if n_batch * 2 == speaker_embedding.size(0):
380
- spe_emb_chunk1, spe_emb_chunk2 = speaker_embedding[:n_batch, :, :], speaker_embedding[n_batch:, :, :]
381
- contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1),
382
- self.logit_scale.exp())
383
-
384
- speaker_embedding = speaker_embedding[:n_batch, :, :]
385
- speaker_embedding = speaker_embedding.transpose(1, 2)
386
-
387
- # upsample feat
388
- if self.feat_upsample:
389
- x = torch.nn.functional.interpolate(
390
- x.transpose(1, 2),
391
- scale_factor=[4],
392
- mode="linear",
393
- ).squeeze(1)
394
- else:
395
- x = x.transpose(1, 2)
396
-
397
- # BigVGAN
398
- # Pre-conv
399
- x = self.conv_pre(x)
400
- x = x + self.cond_layer(speaker_embedding)
401
-
402
- for i in range(self.num_upsamples):
403
- # Upsampling
404
- for i_up in range(len(self.ups[i])):
405
- x = self.ups[i][i_up](x)
406
-
407
- if self.cond_in_each_up_layer:
408
- x = x + self.conds[i](speaker_embedding)
409
-
410
- # AMP blocks
411
- xs = None
412
- for j in range(self.num_kernels):
413
- if xs is None:
414
- xs = self.resblocks[i * self.num_kernels + j](x)
415
- else:
416
- xs += self.resblocks[i * self.num_kernels + j](x)
417
- x = xs / self.num_kernels
418
-
419
- # Post-conv
420
- x = self.activation_post(x)
421
- x = self.conv_post(x)
422
- # Final tanh activation
423
- if self.use_tanh_at_final:
424
- x = torch.tanh(x)
425
- else:
426
- x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
427
-
428
- return x, contrastive_loss
429
-
430
- def remove_weight_norm(self):
431
- try:
432
- print("Removing weight norm...")
433
- for l in self.ups:
434
- for l_i in l:
435
- remove_weight_norm(l_i)
436
- for l in self.resblocks:
437
- l.remove_weight_norm()
438
- remove_weight_norm(self.conv_pre)
439
- remove_weight_norm(self.conv_post)
440
- except ValueError:
441
- print("[INFO] Model already removed weight norm. Skipping!")
442
- pass
443
-
444
- # Additional methods for huggingface_hub support
445
- def _save_pretrained(self, save_directory: Path) -> None:
446
- """Save weights and config.json from a Pytorch model to a local directory."""
447
-
448
- model_path = save_directory / "bigvgan_generator.pt"
449
- torch.save({"generator": self.state_dict()}, model_path)
450
-
451
- config_path = save_directory / "config.json"
452
- with open(config_path, "w") as config_file:
453
- json.dump(self.h, config_file, indent=4)
454
-
455
- @classmethod
456
- def _from_pretrained(
457
- cls,
458
- *,
459
- model_id: str,
460
- revision: str,
461
- cache_dir: str,
462
- force_download: bool,
463
- proxies: Optional[Dict],
464
- resume_download: bool,
465
- local_files_only: bool,
466
- token: Union[str, bool, None],
467
- map_location: str = "cpu", # Additional argument
468
- strict: bool = False, # Additional argument
469
- use_cuda_kernel: bool = False,
470
- **model_kwargs,
471
- ):
472
- """Load Pytorch pretrained weights and return the loaded model."""
473
-
474
- # Download and load hyperparameters (h) used by BigVGAN
475
- if os.path.isdir(model_id):
476
- print("Loading config.json from local directory")
477
- config_file = os.path.join(model_id, "config.json")
478
- else:
479
- config_file = hf_hub_download(
480
- repo_id=model_id,
481
- filename="config.json",
482
- revision=revision,
483
- cache_dir=cache_dir,
484
- force_download=force_download,
485
- proxies=proxies,
486
- resume_download=resume_download,
487
- token=token,
488
- local_files_only=local_files_only,
489
- )
490
- h = load_hparams_from_json(config_file)
491
-
492
- # instantiate BigVGAN using h
493
- if use_cuda_kernel:
494
- print(
495
- f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
496
- )
497
- print(
498
- f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
499
- )
500
- print(
501
- f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
502
- )
503
- model = cls(h, use_cuda_kernel=use_cuda_kernel)
504
-
505
- # Download and load pretrained generator weight
506
- if os.path.isdir(model_id):
507
- print("Loading weights from local directory")
508
- model_file = os.path.join(model_id, "bigvgan_generator.pt")
509
- else:
510
- print(f"Loading weights from {model_id}")
511
- model_file = hf_hub_download(
512
- repo_id=model_id,
513
- filename="bigvgan_generator.pt",
514
- revision=revision,
515
- cache_dir=cache_dir,
516
- force_download=force_download,
517
- proxies=proxies,
518
- resume_download=resume_download,
519
- token=token,
520
- local_files_only=local_files_only,
521
- )
522
-
523
- checkpoint_dict = torch.load(model_file, map_location=map_location)
524
-
525
- try:
526
- model.load_state_dict(checkpoint_dict["generator"])
527
- except RuntimeError:
528
- print(
529
- f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
530
- )
531
- model.remove_weight_norm()
532
- model.load_state_dict(checkpoint_dict["generator"])
533
-
534
- return model
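
BigVGAN here keeps the huggingface_hub hooks (_save_pretrained / _from_pretrained) even though the PyTorchModelHubMixin base class is commented out, so in practice the generator is restored from a local config.json plus bigvgan_generator.pt. A hedged sketch of that path; the checkpoint directory, the tensor lengths, and the [batch, time, bins] layouts assumed for the GPT latents and the reference mel are illustrative, not values taken from the released configs:

    import torch
    from indextts.BigVGAN.bigvgan import BigVGAN, load_hparams_from_json

    h = load_hparams_from_json("checkpoints/config.json")       # hypothetical local path
    model = BigVGAN(h, use_cuda_kernel=False).eval()
    ckpt = torch.load("checkpoints/bigvgan_generator.pt", map_location="cpu")
    model.load_state_dict(ckpt["generator"])
    model.remove_weight_norm()                                   # inference-only, as in _from_pretrained()

    with torch.no_grad():
        feats = torch.randn(1, 100, h.gpt_dim)     # assumed [B, T, gpt_dim] acoustic latents
        mel_ref = torch.randn(1, 200, h.num_mels)  # assumed layout for the ECAPA_TDNN speaker reference
        wav, contrastive = model(feats, mel_ref)   # waveform in [-1, 1]; contrastive is None at inference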
 
indextts/BigVGAN/models.py DELETED
@@ -1,451 +0,0 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION.
2
- # Licensed under the MIT license.
3
-
4
- # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
- # LICENSE is in incl_licenses directory.
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from torch.nn import Conv1d, Conv2d, ConvTranspose1d
10
- from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
11
-
12
- import indextts.BigVGAN.activations as activations
13
-
14
- from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
15
- from indextts.BigVGAN.utils import get_padding, init_weights
16
-
17
- LRELU_SLOPE = 0.1
18
-
19
-
20
- class AMPBlock1(torch.nn.Module):
21
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
22
- super(AMPBlock1, self).__init__()
23
- self.h = h
24
-
25
- self.convs1 = nn.ModuleList([
26
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
27
- padding=get_padding(kernel_size, dilation[0]))),
28
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
29
- padding=get_padding(kernel_size, dilation[1]))),
30
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
31
- padding=get_padding(kernel_size, dilation[2])))
32
- ])
33
- self.convs1.apply(init_weights)
34
-
35
- self.convs2 = nn.ModuleList([
36
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
37
- padding=get_padding(kernel_size, 1))),
38
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
39
- padding=get_padding(kernel_size, 1))),
40
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
41
- padding=get_padding(kernel_size, 1)))
42
- ])
43
- self.convs2.apply(init_weights)
44
-
45
- self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
46
- if self.h.get("use_cuda_kernel", False):
47
- from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
48
- else:
49
- from indextts.BigVGAN.alias_free_torch import Activation1d
50
- if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
51
- self.activations = nn.ModuleList([
52
- Activation1d(
53
- activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
54
- for _ in range(self.num_layers)
55
- ])
56
- elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
57
- self.activations = nn.ModuleList([
58
- Activation1d(
59
- activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
60
- for _ in range(self.num_layers)
61
- ])
62
- else:
63
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
64
-
65
- def forward(self, x):
66
- acts1, acts2 = self.activations[::2], self.activations[1::2]
67
- for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
68
- xt = a1(x)
69
- xt = c1(xt)
70
- xt = a2(xt)
71
- xt = c2(xt)
72
- x = xt + x
73
-
74
- return x
75
-
76
- def remove_weight_norm(self):
77
- for l in self.convs1:
78
- remove_weight_norm(l)
79
- for l in self.convs2:
80
- remove_weight_norm(l)
81
-
82
-
83
- class AMPBlock2(torch.nn.Module):
84
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
85
- super(AMPBlock2, self).__init__()
86
- self.h = h
87
-
88
- self.convs = nn.ModuleList([
89
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
90
- padding=get_padding(kernel_size, dilation[0]))),
91
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
92
- padding=get_padding(kernel_size, dilation[1])))
93
- ])
94
- self.convs.apply(init_weights)
95
-
96
- self.num_layers = len(self.convs) # total number of conv layers
97
- if self.h.get("use_cuda_kernel", False):
98
- from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
99
- else:
100
- from indextts.BigVGAN.alias_free_torch import Activation1d
101
-
102
- if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
103
- self.activations = nn.ModuleList([
104
- Activation1d(
105
- activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
106
- for _ in range(self.num_layers)
107
- ])
108
- elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
109
- self.activations = nn.ModuleList([
110
- Activation1d(
111
- activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
112
- for _ in range(self.num_layers)
113
- ])
114
- else:
115
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
116
-
117
- def forward(self, x):
118
- for c, a in zip(self.convs, self.activations):
119
- xt = a(x)
120
- xt = c(xt)
121
- x = xt + x
122
-
123
- return x
124
-
125
- def remove_weight_norm(self):
126
- for l in self.convs:
127
- remove_weight_norm(l)
128
-
129
-
130
- class BigVGAN(torch.nn.Module):
131
- # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
132
- def __init__(self, h, use_cuda_kernel=False):
133
- """
134
- Args:
135
- h (dict)
136
- use_cuda_kernel (bool): whether to use custom cuda kernel for anti-aliased activation
137
- """
138
- super(BigVGAN, self).__init__()
139
- self.h = h
140
- self.h["use_cuda_kernel"] = use_cuda_kernel
141
-
142
- self.num_kernels = len(h.resblock_kernel_sizes)
143
- self.num_upsamples = len(h.upsample_rates)
144
-
145
- self.feat_upsample = h.feat_upsample
146
- self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
147
-
148
- # pre conv
149
- self.conv_pre = weight_norm(Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3))
150
-
151
- # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
152
- resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2
153
-
154
- # transposed conv-based upsamplers. does not apply anti-aliasing
155
- self.ups = nn.ModuleList()
156
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
157
- self.ups.append(nn.ModuleList([
158
- weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
159
- h.upsample_initial_channel // (2 ** (i + 1)),
160
- k, u, padding=(k - u) // 2))
161
- ]))
162
-
163
- # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
164
- self.resblocks = nn.ModuleList()
165
- for i in range(len(self.ups)):
166
- ch = h.upsample_initial_channel // (2 ** (i + 1))
167
- for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
168
- self.resblocks.append(resblock(self.h, ch, k, d, activation=h.activation))
169
- if use_cuda_kernel:
170
- from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
171
- else:
172
- from indextts.BigVGAN.alias_free_torch import Activation1d
173
-
174
- # post conv
175
- if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
176
- activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
177
- self.activation_post = Activation1d(activation=activation_post)
178
- elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
179
- activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
180
- self.activation_post = Activation1d(activation=activation_post)
181
- else:
182
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
183
-
184
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
185
-
186
- # weight initialization
187
- for i in range(len(self.ups)):
188
- self.ups[i].apply(init_weights)
189
- self.conv_post.apply(init_weights)
190
-
191
- self.speaker_encoder = ECAPA_TDNN(h.num_mels, lin_neurons=h.speaker_embedding_dim)
192
- self.cond_layer = nn.Conv1d(h.speaker_embedding_dim, h.upsample_initial_channel, 1)
193
- if self.cond_in_each_up_layer:
194
- self.conds = nn.ModuleList()
195
- for i in range(len(self.ups)):
196
- ch = h.upsample_initial_channel // (2 ** (i + 1))
197
- self.conds.append(nn.Conv1d(h.speaker_embedding_dim, ch, 1))
198
-
199
- # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
200
-
201
- def forward(self, x, mel_ref, lens=None):
202
- speaker_embedding = self.speaker_encoder(mel_ref, lens)
203
- n_batch = x.size(0)
204
- contrastive_loss = None
205
- if n_batch * 2 == speaker_embedding.size(0):
206
- spe_emb_chunk1, spe_emb_chunk2 = speaker_embedding[:n_batch, :, :], speaker_embedding[n_batch:, :, :]
207
- contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
208
-
209
- speaker_embedding = speaker_embedding[:n_batch, :, :]
210
- speaker_embedding = speaker_embedding.transpose(1, 2)
211
-
212
- # upsample feat
213
- if self.feat_upsample:
214
- x = torch.nn.functional.interpolate(
215
- x.transpose(1, 2),
216
- scale_factor=[4],
217
- mode="linear",
218
- ).squeeze(1)
219
- else:
220
- x = x.transpose(1, 2)
221
-
222
- ### bigVGAN ###
223
- # pre conv
224
- x = self.conv_pre(x)
225
-
226
- x = x + self.cond_layer(speaker_embedding)
227
-
228
- for i in range(self.num_upsamples):
229
- # upsampling
230
- for i_up in range(len(self.ups[i])):
231
- x = self.ups[i][i_up](x)
232
-
233
- if self.cond_in_each_up_layer:
234
- x = x + self.conds[i](speaker_embedding)
235
-
236
- # AMP blocks
237
- xs = None
238
- for j in range(self.num_kernels):
239
- if xs is None:
240
- xs = self.resblocks[i * self.num_kernels + j](x)
241
- else:
242
- xs += self.resblocks[i * self.num_kernels + j](x)
243
- x = xs / self.num_kernels
244
-
245
- # post conv
246
- x = self.activation_post(x)
247
- x = self.conv_post(x)
248
- x = torch.tanh(x)
249
-
250
- return x, contrastive_loss
251
-
252
- def remove_weight_norm(self):
253
- print('Removing weight norm...')
254
- for l in self.ups:
255
- for l_i in l:
256
- remove_weight_norm(l_i)
257
- for l in self.resblocks:
258
- l.remove_weight_norm()
259
- remove_weight_norm(self.conv_pre)
260
- remove_weight_norm(self.conv_post)
261
-
262
- def cal_clip_loss(self, image_features, text_features, logit_scale):
263
- device = image_features.device
264
- logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
265
- labels = torch.arange(logits_per_image.shape[0], device=device, dtype=torch.long)
266
- total_loss = (
267
- F.cross_entropy(logits_per_image, labels) +
268
- F.cross_entropy(logits_per_text, labels)
269
- ) / 2
270
- return total_loss
271
-
272
- def get_logits(self, image_features, text_features, logit_scale):
273
- logits_per_image = logit_scale * image_features @ text_features.T
274
- logits_per_text = logit_scale * text_features @ image_features.T
275
- return logits_per_image, logits_per_text
276
-
277
-
278
- class DiscriminatorP(torch.nn.Module):
279
- def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
280
- super(DiscriminatorP, self).__init__()
281
- self.period = period
282
- self.d_mult = h.discriminator_channel_mult
283
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
284
- self.convs = nn.ModuleList([
285
- norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
286
- norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
287
- norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
288
- norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
289
- norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
290
- ])
291
- self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
292
-
293
- def forward(self, x):
294
- fmap = []
295
-
296
- # 1d to 2d
297
- b, c, t = x.shape
298
- if t % self.period != 0: # pad first
299
- n_pad = self.period - (t % self.period)
300
- x = F.pad(x, (0, n_pad), "reflect")
301
- t = t + n_pad
302
- x = x.view(b, c, t // self.period, self.period)
303
-
304
- for l in self.convs:
305
- x = l(x)
306
- x = F.leaky_relu(x, LRELU_SLOPE)
307
- fmap.append(x)
308
- x = self.conv_post(x)
309
- fmap.append(x)
310
- x = torch.flatten(x, 1, -1)
311
-
312
- return x, fmap
313
-
314
-
315
- class MultiPeriodDiscriminator(torch.nn.Module):
316
- def __init__(self, h):
317
- super(MultiPeriodDiscriminator, self).__init__()
318
- self.mpd_reshapes = h.mpd_reshapes
319
- print("mpd_reshapes: {}".format(self.mpd_reshapes))
320
- discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
321
- self.discriminators = nn.ModuleList(discriminators)
322
-
323
- def forward(self, y, y_hat):
324
- y_d_rs = []
325
- y_d_gs = []
326
- fmap_rs = []
327
- fmap_gs = []
328
- for i, d in enumerate(self.discriminators):
329
- y_d_r, fmap_r = d(y)
330
- y_d_g, fmap_g = d(y_hat)
331
- y_d_rs.append(y_d_r)
332
- fmap_rs.append(fmap_r)
333
- y_d_gs.append(y_d_g)
334
- fmap_gs.append(fmap_g)
335
-
336
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
337
-
338
-
339
- class DiscriminatorR(nn.Module):
340
- def __init__(self, cfg, resolution):
341
- super().__init__()
342
-
343
- self.resolution = resolution
344
- assert len(self.resolution) == 3, \
345
- "MRD layer requires list with len=3, got {}".format(self.resolution)
346
- self.lrelu_slope = LRELU_SLOPE
347
-
348
- norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
349
- if hasattr(cfg, "mrd_use_spectral_norm"):
350
- print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
351
- norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
352
- self.d_mult = cfg.discriminator_channel_mult
353
- if hasattr(cfg, "mrd_channel_mult"):
354
- print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
355
- self.d_mult = cfg.mrd_channel_mult
356
-
357
- self.convs = nn.ModuleList([
358
- norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
359
- norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
360
- norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
361
- norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
362
- norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
363
- ])
364
- self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
365
-
366
- def forward(self, x):
367
- fmap = []
368
-
369
- x = self.spectrogram(x)
370
- x = x.unsqueeze(1)
371
- for l in self.convs:
372
- x = l(x)
373
- x = F.leaky_relu(x, self.lrelu_slope)
374
- fmap.append(x)
375
- x = self.conv_post(x)
376
- fmap.append(x)
377
- x = torch.flatten(x, 1, -1)
378
-
379
- return x, fmap
380
-
381
- def spectrogram(self, x):
382
- n_fft, hop_length, win_length = self.resolution
383
- x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
384
- x = x.squeeze(1)
385
- x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
386
- x = torch.view_as_real(x) # [B, F, TT, 2]
387
- mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
388
-
389
- return mag
390
-
391
-
392
- class MultiResolutionDiscriminator(nn.Module):
393
- def __init__(self, cfg, debug=False):
394
- super().__init__()
395
- self.resolutions = cfg.resolutions
396
- assert len(self.resolutions) == 3, \
397
- "MRD requires list of list with len=3, each element having a list with len=3. got {}".\
398
- format(self.resolutions)
399
- self.discriminators = nn.ModuleList(
400
- [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
401
- )
402
-
403
- def forward(self, y, y_hat):
404
- y_d_rs = []
405
- y_d_gs = []
406
- fmap_rs = []
407
- fmap_gs = []
408
-
409
- for i, d in enumerate(self.discriminators):
410
- y_d_r, fmap_r = d(x=y)
411
- y_d_g, fmap_g = d(x=y_hat)
412
- y_d_rs.append(y_d_r)
413
- fmap_rs.append(fmap_r)
414
- y_d_gs.append(y_d_g)
415
- fmap_gs.append(fmap_g)
416
-
417
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
418
-
419
-
420
- def feature_loss(fmap_r, fmap_g):
421
- loss = 0
422
- for dr, dg in zip(fmap_r, fmap_g):
423
- for rl, gl in zip(dr, dg):
424
- loss += torch.mean(torch.abs(rl - gl))
425
-
426
- return loss * 2
427
-
428
-
429
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
430
- loss = 0
431
- r_losses = []
432
- g_losses = []
433
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
434
- r_loss = torch.mean((1 - dr)**2)
435
- g_loss = torch.mean(dg**2)
436
- loss += (r_loss + g_loss)
437
- r_losses.append(r_loss.item())
438
- g_losses.append(g_loss.item())
439
-
440
- return loss, r_losses, g_losses
441
-
442
-
443
- def generator_loss(disc_outputs):
444
- loss = 0
445
- gen_losses = []
446
- for dg in disc_outputs:
447
- l = torch.mean((1 - dg)**2)
448
- gen_losses.append(l)
449
- loss += l
450
-
451
- return loss, gen_losses
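
The discriminators and loss helpers above follow the HiFi-GAN / BigVGAN least-squares GAN recipe. A sketch of one pass through the multi-period discriminator with the three loss terms; the AttrDict fields and waveform shapes below are typical illustrative values, not the project's training config:

    import torch
    from indextts.BigVGAN.env import AttrDict
    from indextts.BigVGAN.models import (MultiPeriodDiscriminator, discriminator_loss,
                                         generator_loss, feature_loss)

    h = AttrDict({"mpd_reshapes": [2, 3, 5, 7, 11],       # assumed/typical values
                  "use_spectral_norm": False,
                  "discriminator_channel_mult": 1})
    mpd = MultiPeriodDiscriminator(h)

    y = torch.randn(2, 1, 24000)       # real waveform      [B, 1, T]
    y_hat = torch.randn(2, 1, 24000)   # generated waveform
    y_d_r, y_d_g, fmap_r, fmap_g = mpd(y, y_hat)

    loss_disc, _, _ = discriminator_loss(y_d_r, y_d_g)  # least-squares real/fake terms
    loss_adv, _ = generator_loss(y_d_g)                  # generator's adversarial term
    loss_fm = feature_loss(fmap_r, fmap_g)               # L1 feature matching over discriminator maps

In an actual training step the discriminator-side pass would typically use y_hat.detach(), and the MultiResolutionDiscriminator terms are accumulated the same way.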
 
indextts/BigVGAN/nnet/CNN.py DELETED
@@ -1,546 +0,0 @@
1
- """Library implementing convolutional neural networks.
2
-
3
- Authors
4
- * Mirco Ravanelli 2020
5
- * Jianyuan Zhong 2020
6
- * Cem Subakan 2021
7
- * Davide Borra 2021
8
- * Andreas Nautsch 2022
9
- * Sarthak Yadav 2022
10
- """
11
-
12
- import logging
13
- import math
14
- from typing import Tuple
15
-
16
- import numpy as np
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
- import torchaudio
21
-
22
-
23
- class SincConv(nn.Module):
24
- """This function implements SincConv (SincNet).
25
-
26
- M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with
27
- SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158)
28
-
29
- Arguments
30
- ---------
31
- out_channels : int
32
- It is the number of output channels.
33
- kernel_size: int
34
- Kernel size of the convolutional filters.
35
- input_shape : tuple
36
- The shape of the input. Alternatively use ``in_channels``.
37
- in_channels : int
38
- The number of input channels. Alternatively use ``input_shape``.
39
- stride : int
40
- Stride factor of the convolutional filters. When the stride factor > 1,
41
- a decimation in time is performed.
42
- dilation : int
43
- Dilation factor of the convolutional filters.
44
- padding : str
45
- (same, valid, causal). If "valid", no padding is performed.
46
- If "same" and stride is 1, output shape is the same as the input shape.
47
- "causal" results in causal (dilated) convolutions.
48
- padding_mode : str
49
- This flag specifies the type of padding. See torch.nn documentation
50
- for more information.
51
- sample_rate : int
52
- Sampling rate of the input signals. It is only used for sinc_conv.
53
- min_low_hz : float
54
- Lowest possible frequency (in Hz) for a filter. It is only used for
55
- sinc_conv.
56
- min_band_hz : float
57
- Lowest possible value (in Hz) for a filter bandwidth.
58
-
59
- Example
60
- -------
61
- >>> inp_tensor = torch.rand([10, 16000])
62
- >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11)
63
- >>> out_tensor = conv(inp_tensor)
64
- >>> out_tensor.shape
65
- torch.Size([10, 16000, 25])
66
- """
67
-
68
- def __init__(
69
- self,
70
- out_channels,
71
- kernel_size,
72
- input_shape=None,
73
- in_channels=None,
74
- stride=1,
75
- dilation=1,
76
- padding="same",
77
- padding_mode="reflect",
78
- sample_rate=16000,
79
- min_low_hz=50,
80
- min_band_hz=50,
81
- ):
82
- super().__init__()
83
- self.in_channels = in_channels
84
- self.out_channels = out_channels
85
- self.kernel_size = kernel_size
86
- self.stride = stride
87
- self.dilation = dilation
88
- self.padding = padding
89
- self.padding_mode = padding_mode
90
- self.sample_rate = sample_rate
91
- self.min_low_hz = min_low_hz
92
- self.min_band_hz = min_band_hz
93
-
94
- # input shape inference
95
- if input_shape is None and self.in_channels is None:
96
- raise ValueError("Must provide one of input_shape or in_channels")
97
-
98
- if self.in_channels is None:
99
- self.in_channels = self._check_input_shape(input_shape)
100
-
101
- if self.out_channels % self.in_channels != 0:
102
- raise ValueError(
103
- "Number of output channels must be divisible by in_channels"
104
- )
105
-
106
- # Initialize Sinc filters
107
- self._init_sinc_conv()
108
-
109
- def forward(self, x):
110
- """Returns the output of the convolution.
111
-
112
- Arguments
113
- ---------
114
- x : torch.Tensor (batch, time, channel)
115
- input to convolve. 2d or 4d tensors are expected.
116
-
117
- Returns
118
- -------
119
- wx : torch.Tensor
120
- The convolved outputs.
121
- """
122
- x = x.transpose(1, -1)
123
- self.device = x.device
124
-
125
- unsqueeze = x.ndim == 2
126
- if unsqueeze:
127
- x = x.unsqueeze(1)
128
-
129
- if self.padding == "same":
130
- x = self._manage_padding(
131
- x, self.kernel_size, self.dilation, self.stride
132
- )
133
-
134
- elif self.padding == "causal":
135
- num_pad = (self.kernel_size - 1) * self.dilation
136
- x = F.pad(x, (num_pad, 0))
137
-
138
- elif self.padding == "valid":
139
- pass
140
-
141
- else:
142
- raise ValueError(
143
- "Padding must be 'same', 'valid' or 'causal'. Got %s."
144
- % (self.padding)
145
- )
146
-
147
- sinc_filters = self._get_sinc_filters()
148
-
149
- wx = F.conv1d(
150
- x,
151
- sinc_filters,
152
- stride=self.stride,
153
- padding=0,
154
- dilation=self.dilation,
155
- groups=self.in_channels,
156
- )
157
-
158
- if unsqueeze:
159
- wx = wx.squeeze(1)
160
-
161
- wx = wx.transpose(1, -1)
162
-
163
- return wx
164
-
165
- def _check_input_shape(self, shape):
166
- """Checks the input shape and returns the number of input channels."""
167
-
168
- if len(shape) == 2:
169
- in_channels = 1
170
- elif len(shape) == 3:
171
- in_channels = shape[-1]
172
- else:
173
- raise ValueError(
174
- "sincconv expects 2d or 3d inputs. Got " + str(len(shape))
175
- )
176
-
177
- # Kernel size must be odd
178
- if self.kernel_size % 2 == 0:
179
- raise ValueError(
180
- "The field kernel size must be an odd number. Got %s."
181
- % (self.kernel_size)
182
- )
183
- return in_channels
184
-
185
- def _get_sinc_filters(self):
186
- """This functions creates the sinc-filters to used for sinc-conv."""
187
- # Computing the low frequencies of the filters
188
- low = self.min_low_hz + torch.abs(self.low_hz_)
189
-
190
- # Setting minimum band and minimum freq
191
- high = torch.clamp(
192
- low + self.min_band_hz + torch.abs(self.band_hz_),
193
- self.min_low_hz,
194
- self.sample_rate / 2,
195
- )
196
- band = (high - low)[:, 0]
197
-
198
- # Passing from n_ to the corresponding f_times_t domain
199
- self.n_ = self.n_.to(self.device)
200
- self.window_ = self.window_.to(self.device)
201
- f_times_t_low = torch.matmul(low, self.n_)
202
- f_times_t_high = torch.matmul(high, self.n_)
203
-
204
- # Left part of the filters.
205
- band_pass_left = (
206
- (torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
207
- / (self.n_ / 2)
208
- ) * self.window_
209
-
210
- # Central element of the filter
211
- band_pass_center = 2 * band.view(-1, 1)
212
-
213
- # Right part of the filter (sinc filters are symmetric)
214
- band_pass_right = torch.flip(band_pass_left, dims=[1])
215
-
216
- # Combining left, central, and right part of the filter
217
- band_pass = torch.cat(
218
- [band_pass_left, band_pass_center, band_pass_right], dim=1
219
- )
220
-
221
- # Amplitude normalization
222
- band_pass = band_pass / (2 * band[:, None])
223
-
224
- # Setting up the filter coefficients
225
- filters = band_pass.view(self.out_channels, 1, self.kernel_size)
226
-
227
- return filters
228
-
229
- def _init_sinc_conv(self):
230
- """Initializes the parameters of the sinc_conv layer."""
231
-
232
- # Initialize filterbanks such that they are equally spaced in Mel scale
233
- high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
234
-
235
- mel = torch.linspace(
236
- self._to_mel(self.min_low_hz),
237
- self._to_mel(high_hz),
238
- self.out_channels + 1,
239
- )
240
-
241
- hz = self._to_hz(mel)
242
-
243
- # Filter lower frequency and bands
244
- self.low_hz_ = hz[:-1].unsqueeze(1)
245
- self.band_hz_ = (hz[1:] - hz[:-1]).unsqueeze(1)
246
-
247
- # Maiking freq and bands learnable
248
- self.low_hz_ = nn.Parameter(self.low_hz_)
249
- self.band_hz_ = nn.Parameter(self.band_hz_)
250
-
251
- # Hamming window
252
- n_lin = torch.linspace(
253
- 0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
254
- )
255
- self.window_ = 0.54 - 0.46 * torch.cos(
256
- 2 * math.pi * n_lin / self.kernel_size
257
- )
258
-
259
- # Time axis (only half is needed due to symmetry)
260
- n = (self.kernel_size - 1) / 2.0
261
- self.n_ = (
262
- 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
263
- )
264
-
265
- def _to_mel(self, hz):
266
- """Converts frequency in Hz to the mel scale."""
267
- return 2595 * np.log10(1 + hz / 700)
268
-
269
- def _to_hz(self, mel):
270
- """Converts frequency in the mel scale to Hz."""
271
- return 700 * (10 ** (mel / 2595) - 1)
272
-
273
- def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
274
- """This function performs zero-padding on the time axis
275
- such that their lengths is unchanged after the convolution.
276
-
277
- Arguments
278
- ---------
279
- x : torch.Tensor
280
- Input tensor.
281
- kernel_size : int
282
- Size of kernel.
283
- dilation : int
284
- Dilation used.
285
- stride : int
286
- Stride.
287
-
288
- Returns
289
- -------
290
- x : torch.Tensor
291
- """
292
-
293
- # Detecting input shape
294
- L_in = self.in_channels
295
-
296
- # Time padding
297
- padding = get_padding_elem(L_in, stride, kernel_size, dilation)
298
-
299
- # Applying padding
300
- x = F.pad(x, padding, mode=self.padding_mode)
301
-
302
- return x
303
-
304
-
305
- class Conv1d(nn.Module):
306
- """This function implements 1d convolution.
307
-
308
- Arguments
309
- ---------
310
- out_channels : int
311
- It is the number of output channels.
312
- kernel_size : int
313
- Kernel size of the convolutional filters.
314
- input_shape : tuple
315
- The shape of the input. Alternatively use ``in_channels``.
316
- in_channels : int
317
- The number of input channels. Alternatively use ``input_shape``.
318
- stride : int
319
- Stride factor of the convolutional filters. When the stride factor > 1,
320
- a decimation in time is performed.
321
- dilation : int
322
- Dilation factor of the convolutional filters.
323
- padding : str
324
- (same, valid, causal). If "valid", no padding is performed.
325
- If "same" and stride is 1, output shape is the same as the input shape.
326
- "causal" results in causal (dilated) convolutions.
327
- groups : int
328
- Number of blocked connections from input channels to output channels.
329
- bias : bool
330
- Whether to add a bias term to convolution operation.
331
- padding_mode : str
332
- This flag specifies the type of padding. See torch.nn documentation
333
- for more information.
334
- skip_transpose : bool
335
- If False, uses batch x time x channel convention of speechbrain.
336
- If True, uses batch x channel x time convention.
337
- weight_norm : bool
338
- If True, use weight normalization,
339
- to be removed with self.remove_weight_norm() at inference
340
- conv_init : str
341
- Weight initialization for the convolution network
342
- default_padding: str or int
343
- This sets the default padding mode that will be used by the pytorch Conv1d backend.
344
-
345
- Example
346
- -------
347
- >>> inp_tensor = torch.rand([10, 40, 16])
348
- >>> cnn_1d = Conv1d(
349
- ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
350
- ... )
351
- >>> out_tensor = cnn_1d(inp_tensor)
352
- >>> out_tensor.shape
353
- torch.Size([10, 40, 8])
354
- """
355
-
356
- def __init__(
357
- self,
358
- out_channels,
359
- kernel_size,
360
- input_shape=None,
361
- in_channels=None,
362
- stride=1,
363
- dilation=1,
364
- padding="same",
365
- groups=1,
366
- bias=True,
367
- padding_mode="reflect",
368
- skip_transpose=False,
369
- weight_norm=False,
370
- conv_init=None,
371
- default_padding=0,
372
- ):
373
- super().__init__()
374
- self.kernel_size = kernel_size
375
- self.stride = stride
376
- self.dilation = dilation
377
- self.padding = padding
378
- self.padding_mode = padding_mode
379
- self.unsqueeze = False
380
- self.skip_transpose = skip_transpose
381
-
382
- if input_shape is None and in_channels is None:
383
- raise ValueError("Must provide one of input_shape or in_channels")
384
-
385
- if in_channels is None:
386
- in_channels = self._check_input_shape(input_shape)
387
-
388
- self.in_channels = in_channels
389
-
390
- self.conv = nn.Conv1d(
391
- in_channels,
392
- out_channels,
393
- self.kernel_size,
394
- stride=self.stride,
395
- dilation=self.dilation,
396
- padding=default_padding,
397
- groups=groups,
398
- bias=bias,
399
- )
400
-
401
- if conv_init == "kaiming":
402
- nn.init.kaiming_normal_(self.conv.weight)
403
- elif conv_init == "zero":
404
- nn.init.zeros_(self.conv.weight)
405
- elif conv_init == "normal":
406
- nn.init.normal_(self.conv.weight, std=1e-6)
407
-
408
- if weight_norm:
409
- self.conv = nn.utils.weight_norm(self.conv)
410
-
411
- def forward(self, x):
412
- """Returns the output of the convolution.
413
-
414
- Arguments
415
- ---------
416
- x : torch.Tensor (batch, time, channel)
417
- input to convolve. 2d or 4d tensors are expected.
418
-
419
- Returns
420
- -------
421
- wx : torch.Tensor
422
- The convolved outputs.
423
- """
424
- if not self.skip_transpose:
425
- x = x.transpose(1, -1)
426
-
427
- if self.unsqueeze:
428
- x = x.unsqueeze(1)
429
-
430
- if self.padding == "same":
431
- x = self._manage_padding(
432
- x, self.kernel_size, self.dilation, self.stride
433
- )
434
-
435
- elif self.padding == "causal":
436
- num_pad = (self.kernel_size - 1) * self.dilation
437
- x = F.pad(x, (num_pad, 0))
438
-
439
- elif self.padding == "valid":
440
- pass
441
-
442
- else:
443
- raise ValueError(
444
- "Padding must be 'same', 'valid' or 'causal'. Got "
445
- + self.padding
446
- )
447
-
448
- wx = self.conv(x)
449
-
450
- if self.unsqueeze:
451
- wx = wx.squeeze(1)
452
-
453
- if not self.skip_transpose:
454
- wx = wx.transpose(1, -1)
455
-
456
- return wx
457
-
458
- def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
459
- """This function performs zero-padding on the time axis
460
- such that their lengths is unchanged after the convolution.
461
-
462
- Arguments
463
- ---------
464
- x : torch.Tensor
465
- Input tensor.
466
- kernel_size : int
467
- Size of kernel.
468
- dilation : int
469
- Dilation used.
470
- stride : int
471
- Stride.
472
-
473
- Returns
474
- -------
475
- x : torch.Tensor
476
- The padded outputs.
477
- """
478
-
479
- # Detecting input shape
480
- L_in = self.in_channels
481
-
482
- # Time padding
483
- padding = get_padding_elem(L_in, stride, kernel_size, dilation)
484
-
485
- # Applying padding
486
- x = F.pad(x, padding, mode=self.padding_mode)
487
-
488
- return x
489
-
490
- def _check_input_shape(self, shape):
491
- """Checks the input shape and returns the number of input channels."""
492
-
493
- if len(shape) == 2:
494
- self.unsqueeze = True
495
- in_channels = 1
496
- elif self.skip_transpose:
497
- in_channels = shape[1]
498
- elif len(shape) == 3:
499
- in_channels = shape[2]
500
- else:
501
- raise ValueError(
502
- "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
503
- )
504
-
505
- # Kernel size must be odd
506
- if not self.padding == "valid" and self.kernel_size % 2 == 0:
507
- raise ValueError(
508
- "The field kernel size must be an odd number. Got %s."
509
- % (self.kernel_size)
510
- )
511
-
512
- return in_channels
513
-
514
- def remove_weight_norm(self):
515
- """Removes weight normalization at inference if used during training."""
516
- self.conv = nn.utils.remove_weight_norm(self.conv)
517
-
518
-
519
- def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
520
- """This function computes the number of elements to add for zero-padding.
521
-
522
- Arguments
523
- ---------
524
- L_in : int
525
- stride: int
526
- kernel_size : int
527
- dilation : int
528
-
529
- Returns
530
- -------
531
- padding : int
532
- The size of the padding to be added
533
- """
534
- if stride > 1:
535
- padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
536
-
537
- else:
538
- L_out = (
539
- math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
540
- )
541
- padding = [
542
- math.floor((L_in - L_out) / 2),
543
- math.floor((L_in - L_out) / 2),
544
- ]
545
- return padding
546
-
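For reference, the stride-1 branch of get_padding_elem above boils down to padding half of the receptive-field growth on each side. A standalone sketch of that rule, assuming stride 1 and an odd kernel; the shapes and the reflect padding mode are illustrative values, not taken from any shipped configuration:

```python
import torch
import torch.nn.functional as F

# Stride-1 case of get_padding_elem: without padding, a dilated conv shortens the
# time axis to L_out = L_in - dilation * (kernel_size - 1), so padding
# (L_in - L_out) // 2 samples on each side keeps the length unchanged.
L_in, kernel_size, dilation = 16, 5, 2
L_out = L_in - dilation * (kernel_size - 1)
pad = (L_in - L_out) // 2                          # 4 samples per side for these values

x = torch.rand(10, 40, L_in)                       # (batch, channels, time)
x_padded = F.pad(x, (pad, pad), mode="reflect")    # Conv1d above pads before convolving
conv = torch.nn.Conv1d(40, 8, kernel_size, dilation=dilation)
print(conv(x_padded).shape)                        # torch.Size([10, 8, 16]) -- time preserved
```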
indextts/BigVGAN/nnet/__init__.py DELETED
File without changes
indextts/BigVGAN/nnet/linear.py DELETED
@@ -1,89 +0,0 @@
- """Library implementing linear transformation.
-
- Authors
-  * Mirco Ravanelli 2020
-  * Davide Borra 2021
- """
-
- import logging
-
- import torch
- import torch.nn as nn
-
-
- class Linear(torch.nn.Module):
-     """Computes a linear transformation y = wx + b.
-
-     Arguments
-     ---------
-     n_neurons : int
-         It is the number of output neurons (i.e, the dimensionality of the
-         output).
-     input_shape : tuple
-         It is the shape of the input tensor.
-     input_size : int
-         Size of the input tensor.
-     bias : bool
-         If True, the additive bias b is adopted.
-     max_norm : float
-         weight max-norm.
-     combine_dims : bool
-         If True and the input is 4D, combine 3rd and 4th dimensions of input.
-
-     Example
-     -------
-     >>> inputs = torch.rand(10, 50, 40)
-     >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
-     >>> output = lin_t(inputs)
-     >>> output.shape
-     torch.Size([10, 50, 100])
-     """
-
-     def __init__(
-         self,
-         n_neurons,
-         input_shape=None,
-         input_size=None,
-         bias=True,
-         max_norm=None,
-         combine_dims=False,
-     ):
-         super().__init__()
-         self.max_norm = max_norm
-         self.combine_dims = combine_dims
-
-         if input_shape is None and input_size is None:
-             raise ValueError("Expected one of input_shape or input_size")
-
-         if input_size is None:
-             input_size = input_shape[-1]
-             if len(input_shape) == 4 and self.combine_dims:
-                 input_size = input_shape[2] * input_shape[3]
-
-         # Weights are initialized following pytorch approach
-         self.w = nn.Linear(input_size, n_neurons, bias=bias)
-
-     def forward(self, x):
-         """Returns the linear transformation of input tensor.
-
-         Arguments
-         ---------
-         x : torch.Tensor
-             Input to transform linearly.
-
-         Returns
-         -------
-         wx : torch.Tensor
-             The linearly transformed outputs.
-         """
-         if x.ndim == 4 and self.combine_dims:
-             x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
-
-         if self.max_norm is not None:
-             self.w.weight.data = torch.renorm(
-                 self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm
-             )
-
-         wx = self.w(x)
-
-         return wx
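The one non-standard behaviour of this wrapper is the optional max-norm constraint applied in forward. A minimal sketch of what that torch.renorm call does, with arbitrary example sizes:

```python
import torch

# torch.renorm rescales every slice along dim 0 (one row per output neuron in the
# Linear wrapper above) so that its L2 norm is at most maxnorm; rows already below
# the bound are left untouched.
weight = torch.randn(100, 40) * 5.0                      # (n_neurons, input_size)
clipped = torch.renorm(weight, p=2, dim=0, maxnorm=1.0)
print(weight.norm(dim=1).max())                          # typically well above 1
print(clipped.norm(dim=1).max())                         # <= 1 up to float tolerance
```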
indextts/BigVGAN/nnet/normalization.py DELETED
@@ -1,670 +0,0 @@
1
- """Library implementing normalization.
2
-
3
- Authors
4
- * Mirco Ravanelli 2020
5
- * Guillermo Cámbara 2021
6
- * Sarthak Yadav 2022
7
- """
8
-
9
- import torch
10
- import torch.nn as nn
11
-
12
-
13
- class BatchNorm1d(nn.Module):
14
- """Applies 1d batch normalization to the input tensor.
15
-
16
- Arguments
17
- ---------
18
- input_shape : tuple
19
- The expected shape of the input. Alternatively, use ``input_size``.
20
- input_size : int
21
- The expected size of the input. Alternatively, use ``input_shape``.
22
- eps : float
23
- This value is added to std deviation estimation to improve the numerical
24
- stability.
25
- momentum : float
26
- It is a value used for the running_mean and running_var computation.
27
- affine : bool
28
- When set to True, the affine parameters are learned.
29
- track_running_stats : bool
30
- When set to True, this module tracks the running mean and variance,
31
- and when set to False, this module does not track such statistics.
32
- combine_batch_time : bool
33
- When true, it combines batch an time axis.
34
- skip_transpose : bool
35
- Whether to skip the transposition.
36
-
37
-
38
- Example
39
- -------
40
- >>> input = torch.randn(100, 10)
41
- >>> norm = BatchNorm1d(input_shape=input.shape)
42
- >>> output = norm(input)
43
- >>> output.shape
44
- torch.Size([100, 10])
45
- """
46
-
47
- def __init__(
48
- self,
49
- input_shape=None,
50
- input_size=None,
51
- eps=1e-05,
52
- momentum=0.1,
53
- affine=True,
54
- track_running_stats=True,
55
- combine_batch_time=False,
56
- skip_transpose=False,
57
- ):
58
- super().__init__()
59
- self.combine_batch_time = combine_batch_time
60
- self.skip_transpose = skip_transpose
61
-
62
- if input_size is None and skip_transpose:
63
- input_size = input_shape[1]
64
- elif input_size is None:
65
- input_size = input_shape[-1]
66
-
67
- self.norm = nn.BatchNorm1d(
68
- input_size,
69
- eps=eps,
70
- momentum=momentum,
71
- affine=affine,
72
- track_running_stats=track_running_stats,
73
- )
74
-
75
- def forward(self, x):
76
- """Returns the normalized input tensor.
77
-
78
- Arguments
79
- ---------
80
- x : torch.Tensor (batch, time, [channels])
81
- input to normalize. 2d or 3d tensors are expected in input
82
- 4d tensors can be used when combine_dims=True.
83
-
84
- Returns
85
- -------
86
- x_n : torch.Tensor
87
- The normalized outputs.
88
- """
89
- shape_or = x.shape
90
- if self.combine_batch_time:
91
- if x.ndim == 3:
92
- x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
93
- else:
94
- x = x.reshape(
95
- shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
96
- )
97
-
98
- elif not self.skip_transpose:
99
- x = x.transpose(-1, 1)
100
-
101
- x_n = self.norm(x)
102
-
103
- if self.combine_batch_time:
104
- x_n = x_n.reshape(shape_or)
105
- elif not self.skip_transpose:
106
- x_n = x_n.transpose(1, -1)
107
-
108
- return x_n
109
-
110
-
111
- class BatchNorm2d(nn.Module):
112
- """Applies 2d batch normalization to the input tensor.
113
-
114
- Arguments
115
- ---------
116
- input_shape : tuple
117
- The expected shape of the input. Alternatively, use ``input_size``.
118
- input_size : int
119
- The expected size of the input. Alternatively, use ``input_shape``.
120
- eps : float
121
- This value is added to std deviation estimation to improve the numerical
122
- stability.
123
- momentum : float
124
- It is a value used for the running_mean and running_var computation.
125
- affine : bool
126
- When set to True, the affine parameters are learned.
127
- track_running_stats : bool
128
- When set to True, this module tracks the running mean and variance,
129
- and when set to False, this module does not track such statistics.
130
-
131
- Example
132
- -------
133
- >>> input = torch.randn(100, 10, 5, 20)
134
- >>> norm = BatchNorm2d(input_shape=input.shape)
135
- >>> output = norm(input)
136
- >>> output.shape
137
- torch.Size([100, 10, 5, 20])
138
- """
139
-
140
- def __init__(
141
- self,
142
- input_shape=None,
143
- input_size=None,
144
- eps=1e-05,
145
- momentum=0.1,
146
- affine=True,
147
- track_running_stats=True,
148
- ):
149
- super().__init__()
150
-
151
- if input_shape is None and input_size is None:
152
- raise ValueError("Expected input_shape or input_size as input")
153
-
154
- if input_size is None:
155
- input_size = input_shape[-1]
156
-
157
- self.norm = nn.BatchNorm2d(
158
- input_size,
159
- eps=eps,
160
- momentum=momentum,
161
- affine=affine,
162
- track_running_stats=track_running_stats,
163
- )
164
-
165
- def forward(self, x):
166
- """Returns the normalized input tensor.
167
-
168
- Arguments
169
- ---------
170
- x : torch.Tensor (batch, time, channel1, channel2)
171
- input to normalize. 4d tensors are expected.
172
-
173
- Returns
174
- -------
175
- x_n : torch.Tensor
176
- The normalized outputs.
177
- """
178
- x = x.transpose(-1, 1)
179
- x_n = self.norm(x)
180
- x_n = x_n.transpose(1, -1)
181
-
182
- return x_n
183
-
184
-
185
- class LayerNorm(nn.Module):
186
- """Applies layer normalization to the input tensor.
187
-
188
- Arguments
189
- ---------
190
- input_size : int
191
- The expected size of the dimension to be normalized.
192
- input_shape : tuple
193
- The expected shape of the input.
194
- eps : float
195
- This value is added to std deviation estimation to improve the numerical
196
- stability.
197
- elementwise_affine : bool
198
- If True, this module has learnable per-element affine parameters
199
- initialized to ones (for weights) and zeros (for biases).
200
-
201
- Example
202
- -------
203
- >>> input = torch.randn(100, 101, 128)
204
- >>> norm = LayerNorm(input_shape=input.shape)
205
- >>> output = norm(input)
206
- >>> output.shape
207
- torch.Size([100, 101, 128])
208
- """
209
-
210
- def __init__(
211
- self,
212
- input_size=None,
213
- input_shape=None,
214
- eps=1e-05,
215
- elementwise_affine=True,
216
- ):
217
- super().__init__()
218
- self.eps = eps
219
- self.elementwise_affine = elementwise_affine
220
-
221
- if input_shape is not None:
222
- input_size = input_shape[2:]
223
-
224
- self.norm = torch.nn.LayerNorm(
225
- input_size,
226
- eps=self.eps,
227
- elementwise_affine=self.elementwise_affine,
228
- )
229
-
230
- def forward(self, x):
231
- """Returns the normalized input tensor.
232
-
233
- Arguments
234
- ---------
235
- x : torch.Tensor (batch, time, channels)
236
- input to normalize. 3d or 4d tensors are expected.
237
-
238
- Returns
239
- -------
240
- The normalized outputs.
241
- """
242
- return self.norm(x)
243
-
244
-
245
- class InstanceNorm1d(nn.Module):
246
- """Applies 1d instance normalization to the input tensor.
247
-
248
- Arguments
249
- ---------
250
- input_shape : tuple
251
- The expected shape of the input. Alternatively, use ``input_size``.
252
- input_size : int
253
- The expected size of the input. Alternatively, use ``input_shape``.
254
- eps : float
255
- This value is added to std deviation estimation to improve the numerical
256
- stability.
257
- momentum : float
258
- It is a value used for the running_mean and running_var computation.
259
- track_running_stats : bool
260
- When set to True, this module tracks the running mean and variance,
261
- and when set to False, this module does not track such statistics.
262
- affine : bool
263
- A boolean value that when set to True, this module has learnable
264
- affine parameters, initialized the same way as done for
265
- batch normalization. Default: False.
266
-
267
- Example
268
- -------
269
- >>> input = torch.randn(100, 10, 20)
270
- >>> norm = InstanceNorm1d(input_shape=input.shape)
271
- >>> output = norm(input)
272
- >>> output.shape
273
- torch.Size([100, 10, 20])
274
- """
275
-
276
- def __init__(
277
- self,
278
- input_shape=None,
279
- input_size=None,
280
- eps=1e-05,
281
- momentum=0.1,
282
- track_running_stats=True,
283
- affine=False,
284
- ):
285
- super().__init__()
286
-
287
- if input_shape is None and input_size is None:
288
- raise ValueError("Expected input_shape or input_size as input")
289
-
290
- if input_size is None:
291
- input_size = input_shape[-1]
292
-
293
- self.norm = nn.InstanceNorm1d(
294
- input_size,
295
- eps=eps,
296
- momentum=momentum,
297
- track_running_stats=track_running_stats,
298
- affine=affine,
299
- )
300
-
301
- def forward(self, x):
302
- """Returns the normalized input tensor.
303
-
304
- Arguments
305
- ---------
306
- x : torch.Tensor (batch, time, channels)
307
- input to normalize. 3d tensors are expected.
308
-
309
- Returns
310
- -------
311
- x_n : torch.Tensor
312
- The normalized outputs.
313
- """
314
- x = x.transpose(-1, 1)
315
- x_n = self.norm(x)
316
- x_n = x_n.transpose(1, -1)
317
-
318
- return x_n
319
-
320
-
321
- class InstanceNorm2d(nn.Module):
322
- """Applies 2d instance normalization to the input tensor.
323
-
324
- Arguments
325
- ---------
326
- input_shape : tuple
327
- The expected shape of the input. Alternatively, use ``input_size``.
328
- input_size : int
329
- The expected size of the input. Alternatively, use ``input_shape``.
330
- eps : float
331
- This value is added to std deviation estimation to improve the numerical
332
- stability.
333
- momentum : float
334
- It is a value used for the running_mean and running_var computation.
335
- track_running_stats : bool
336
- When set to True, this module tracks the running mean and variance,
337
- and when set to False, this module does not track such statistics.
338
- affine : bool
339
- A boolean value that when set to True, this module has learnable
340
- affine parameters, initialized the same way as done for
341
- batch normalization. Default: False.
342
-
343
- Example
344
- -------
345
- >>> input = torch.randn(100, 10, 20, 2)
346
- >>> norm = InstanceNorm2d(input_shape=input.shape)
347
- >>> output = norm(input)
348
- >>> output.shape
349
- torch.Size([100, 10, 20, 2])
350
- """
351
-
352
- def __init__(
353
- self,
354
- input_shape=None,
355
- input_size=None,
356
- eps=1e-05,
357
- momentum=0.1,
358
- track_running_stats=True,
359
- affine=False,
360
- ):
361
- super().__init__()
362
-
363
- if input_shape is None and input_size is None:
364
- raise ValueError("Expected input_shape or input_size as input")
365
-
366
- if input_size is None:
367
- input_size = input_shape[-1]
368
-
369
- self.norm = nn.InstanceNorm2d(
370
- input_size,
371
- eps=eps,
372
- momentum=momentum,
373
- track_running_stats=track_running_stats,
374
- affine=affine,
375
- )
376
-
377
- def forward(self, x):
378
- """Returns the normalized input tensor.
379
-
380
- Arguments
381
- ---------
382
- x : torch.Tensor (batch, time, channel1, channel2)
383
- input to normalize. 4d tensors are expected.
384
-
385
- Returns
386
- -------
387
- x_n : torch.Tensor
388
- The normalized outputs.
389
- """
390
- x = x.transpose(-1, 1)
391
- x_n = self.norm(x)
392
- x_n = x_n.transpose(1, -1)
393
-
394
- return x_n
395
-
396
-
397
- class GroupNorm(nn.Module):
398
- """Applies group normalization to the input tensor.
399
-
400
- Arguments
401
- ---------
402
- input_shape : tuple
403
- The expected shape of the input. Alternatively, use ``input_size``.
404
- input_size : int
405
- The expected size of the input. Alternatively, use ``input_shape``.
406
- num_groups : int
407
- Number of groups to separate the channels into.
408
- eps : float
409
- This value is added to std deviation estimation to improve the numerical
410
- stability.
411
- affine : bool
412
- A boolean value that when set to True, this module has learnable per-channel
413
- affine parameters initialized to ones (for weights) and zeros (for biases).
414
-
415
- Example
416
- -------
417
- >>> input = torch.randn(100, 101, 128)
418
- >>> norm = GroupNorm(input_size=128, num_groups=128)
419
- >>> output = norm(input)
420
- >>> output.shape
421
- torch.Size([100, 101, 128])
422
- """
423
-
424
- def __init__(
425
- self,
426
- input_shape=None,
427
- input_size=None,
428
- num_groups=None,
429
- eps=1e-05,
430
- affine=True,
431
- ):
432
- super().__init__()
433
- self.eps = eps
434
- self.affine = affine
435
-
436
- if input_shape is None and input_size is None:
437
- raise ValueError("Expected input_shape or input_size as input")
438
-
439
- if num_groups is None:
440
- raise ValueError("Expected num_groups as input")
441
-
442
- if input_shape is not None:
443
- input_size = input_shape[-1]
444
-
445
- self.norm = torch.nn.GroupNorm(
446
- num_groups,
447
- input_size,
448
- eps=self.eps,
449
- affine=self.affine,
450
- )
451
-
452
- def forward(self, x):
453
- """Returns the normalized input tensor.
454
-
455
- Arguments
456
- ---------
457
- x : torch.Tensor (batch, time, channels)
458
- input to normalize. 3d or 4d tensors are expected.
459
-
460
- Returns
461
- -------
462
- x_n : torch.Tensor
463
- The normalized outputs.
464
- """
465
- x = x.transpose(-1, 1)
466
- x_n = self.norm(x)
467
- x_n = x_n.transpose(1, -1)
468
-
469
- return x_n
470
-
471
-
472
- class ExponentialMovingAverage(nn.Module):
473
- """
474
- Applies learnable exponential moving average, as required by learnable PCEN layer
475
-
476
- Arguments
477
- ---------
478
- input_size : int
479
- The expected size of the input.
480
- coeff_init: float
481
- Initial smoothing coefficient value
482
- per_channel: bool
483
- Controls whether every smoothing coefficients are learned
484
- independently for every input channel
485
- trainable: bool
486
- whether to learn the PCEN parameters or use fixed
487
- skip_transpose : bool
488
- If False, uses batch x time x channel convention of speechbrain.
489
- If True, uses batch x channel x time convention.
490
-
491
- Example
492
- -------
493
- >>> inp_tensor = torch.rand([10, 50, 40])
494
- >>> pcen = ExponentialMovingAverage(40)
495
- >>> out_tensor = pcen(inp_tensor)
496
- >>> out_tensor.shape
497
- torch.Size([10, 50, 40])
498
- """
499
-
500
- def __init__(
501
- self,
502
- input_size: int,
503
- coeff_init: float = 0.04,
504
- per_channel: bool = False,
505
- trainable: bool = True,
506
- skip_transpose: bool = False,
507
- ):
508
- super().__init__()
509
- self._coeff_init = coeff_init
510
- self._per_channel = per_channel
511
- self.skip_transpose = skip_transpose
512
- self.trainable = trainable
513
- weights = (
514
- torch.ones(
515
- input_size,
516
- )
517
- if self._per_channel
518
- else torch.ones(
519
- 1,
520
- )
521
- )
522
- self._weights = nn.Parameter(
523
- weights * self._coeff_init, requires_grad=trainable
524
- )
525
-
526
- def forward(self, x):
527
- """Returns the normalized input tensor.
528
-
529
- Arguments
530
- ---------
531
- x : torch.Tensor (batch, time, channels)
532
- input to normalize.
533
- """
534
- if not self.skip_transpose:
535
- x = x.transpose(1, -1)
536
- w = torch.clamp(self._weights, min=0.0, max=1.0)
537
- initial_state = x[:, :, 0]
538
-
539
- def scan(init_state, x, w):
540
- """Loops and accumulates."""
541
- x = x.permute(2, 0, 1)
542
- acc = init_state
543
- results = []
544
- for ix in range(x.shape[0]):
545
- acc = (w * x[ix]) + ((1.0 - w) * acc)
546
- results.append(acc.unsqueeze(0))
547
- results = torch.cat(results, dim=0)
548
- results = results.permute(1, 2, 0)
549
- return results
550
-
551
- output = scan(initial_state, x, w)
552
- if not self.skip_transpose:
553
- output = output.transpose(1, -1)
554
- return output
555
-
556
-
557
- class PCEN(nn.Module):
558
- """
559
- This class implements a learnable Per-channel energy normalization (PCEN) layer, supporting both
560
- original PCEN as specified in [1] as well as sPCEN as specified in [2]
561
-
562
- [1] Yuxuan Wang, Pascal Getreuer, Thad Hughes, Richard F. Lyon, Rif A. Saurous, "Trainable Frontend For
563
- Robust and Far-Field Keyword Spotting", in Proc of ICASSP 2017 (https://arxiv.org/abs/1607.05666)
564
-
565
- [2] Neil Zeghidour, Olivier Teboul, F{\'e}lix de Chaumont Quitry & Marco Tagliasacchi, "LEAF: A LEARNABLE FRONTEND
566
- FOR AUDIO CLASSIFICATION", in Proc of ICLR 2021 (https://arxiv.org/abs/2101.08596)
567
-
568
- The default argument values correspond with those used by [2].
569
-
570
- Arguments
571
- ---------
572
- input_size : int
573
- The expected size of the input.
574
- alpha: float
575
- specifies alpha coefficient for PCEN
576
- smooth_coef: float
577
- specified smooth coefficient for PCEN
578
- delta: float
579
- specifies delta coefficient for PCEN
580
- root: float
581
- specifies root coefficient for PCEN
582
- floor: float
583
- specifies floor coefficient for PCEN
584
- trainable: bool
585
- whether to learn the PCEN parameters or use fixed
586
- per_channel_smooth_coef: bool
587
- whether to learn independent smooth coefficients for every channel.
588
- when True, essentially using sPCEN from [2]
589
- skip_transpose : bool
590
- If False, uses batch x time x channel convention of speechbrain.
591
- If True, uses batch x channel x time convention.
592
-
593
- Example
594
- -------
595
- >>> inp_tensor = torch.rand([10, 50, 40])
596
- >>> pcen = PCEN(40, alpha=0.96) # sPCEN
597
- >>> out_tensor = pcen(inp_tensor)
598
- >>> out_tensor.shape
599
- torch.Size([10, 50, 40])
600
- """
601
-
602
- def __init__(
603
- self,
604
- input_size,
605
- alpha: float = 0.96,
606
- smooth_coef: float = 0.04,
607
- delta: float = 2.0,
608
- root: float = 2.0,
609
- floor: float = 1e-12,
610
- trainable: bool = True,
611
- per_channel_smooth_coef: bool = True,
612
- skip_transpose: bool = False,
613
- ):
614
- super().__init__()
615
- self._smooth_coef = smooth_coef
616
- self._floor = floor
617
- self._per_channel_smooth_coef = per_channel_smooth_coef
618
- self.skip_transpose = skip_transpose
619
- self.alpha = nn.Parameter(
620
- torch.ones(input_size) * alpha, requires_grad=trainable
621
- )
622
- self.delta = nn.Parameter(
623
- torch.ones(input_size) * delta, requires_grad=trainable
624
- )
625
- self.root = nn.Parameter(
626
- torch.ones(input_size) * root, requires_grad=trainable
627
- )
628
-
629
- self.ema = ExponentialMovingAverage(
630
- input_size,
631
- coeff_init=self._smooth_coef,
632
- per_channel=self._per_channel_smooth_coef,
633
- skip_transpose=True,
634
- trainable=trainable,
635
- )
636
-
637
- def forward(self, x):
638
- """Returns the normalized input tensor.
639
-
640
- Arguments
641
- ---------
642
- x : torch.Tensor (batch, time, channels)
643
- input to normalize.
644
-
645
- Returns
646
- -------
647
- output : torch.Tensor
648
- The normalized outputs.
649
- """
650
- if not self.skip_transpose:
651
- x = x.transpose(1, -1)
652
- alpha = torch.min(
653
- self.alpha, torch.tensor(1.0, dtype=x.dtype, device=x.device)
654
- )
655
- root = torch.max(
656
- self.root, torch.tensor(1.0, dtype=x.dtype, device=x.device)
657
- )
658
- ema_smoother = self.ema(x)
659
- one_over_root = 1.0 / root
660
- output = (
661
- x / (self._floor + ema_smoother) ** alpha.view(1, -1, 1)
662
- + self.delta.view(1, -1, 1)
663
- ) ** one_over_root.view(1, -1, 1) - self.delta.view(
664
- 1, -1, 1
665
- ) ** one_over_root.view(
666
- 1, -1, 1
667
- )
668
- if not self.skip_transpose:
669
- output = output.transpose(1, -1)
670
- return output
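A scalar-coefficient restatement of the PCEN computation may make the broadcasted version in PCEN.forward above easier to read. This is a sketch for a single (channels, time) input with fixed coefficients, not the trainable module itself:

```python
import torch

# M_t = s * x_t + (1 - s) * M_{t-1}  (the ExponentialMovingAverage smoother; its
# output at the first frame equals the frame itself), followed by
# out = (x / (floor + M) ** alpha + delta) ** (1 / root) - delta ** (1 / root).
def pcen_reference(x, alpha=0.96, s=0.04, delta=2.0, root=2.0, floor=1e-12):
    m = torch.zeros_like(x)
    m[:, 0] = x[:, 0]
    for t in range(1, x.shape[1]):
        m[:, t] = s * x[:, t] + (1.0 - s) * m[:, t - 1]
    r = 1.0 / root
    return (x / (floor + m) ** alpha + delta) ** r - delta ** r

spec = torch.rand(40, 50)              # e.g. a (mel_bins, frames) magnitude spectrogram
print(pcen_reference(spec).shape)      # torch.Size([40, 50])
```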
indextts/BigVGAN/utils.py DELETED
@@ -1,101 +0,0 @@
- # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
- # LICENSE is in incl_licenses directory.
-
- import glob
- import os
-
- import matplotlib
- import matplotlib.pylab as plt
- import torch
- from scipy.io.wavfile import write
- from torch.nn.utils import weight_norm
-
- matplotlib.use("Agg")
-
- MAX_WAV_VALUE = 32768.0
-
-
- def plot_spectrogram(spectrogram):
-     fig, ax = plt.subplots(figsize=(10, 2))
-     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-     plt.colorbar(im, ax=ax)
-
-     fig.canvas.draw()
-     plt.close()
-
-     return fig
-
-
- def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
-     fig, ax = plt.subplots(figsize=(10, 2))
-     im = ax.imshow(
-         spectrogram,
-         aspect="auto",
-         origin="lower",
-         interpolation="none",
-         vmin=1e-6,
-         vmax=clip_max,
-     )
-     plt.colorbar(im, ax=ax)
-
-     fig.canvas.draw()
-     plt.close()
-
-     return fig
-
-
- def init_weights(m, mean=0.0, std=0.01):
-     classname = m.__class__.__name__
-     if classname.find("Conv") != -1:
-         m.weight.data.normal_(mean, std)
-
-
- def apply_weight_norm(m):
-     classname = m.__class__.__name__
-     if classname.find("Conv") != -1:
-         weight_norm(m)
-
-
- def get_padding(kernel_size, dilation=1):
-     return int((kernel_size * dilation - dilation) / 2)
-
-
- def load_checkpoint(filepath, device):
-     assert os.path.isfile(filepath)
-     print(f"Loading '{filepath}'")
-     checkpoint_dict = torch.load(filepath, map_location=device)
-     print("Complete.")
-     return checkpoint_dict
-
-
- def save_checkpoint(filepath, obj):
-     print(f"Saving checkpoint to {filepath}")
-     torch.save(obj, filepath)
-     print("Complete.")
-
-
- def scan_checkpoint(cp_dir, prefix, renamed_file=None):
-     # Fallback to original scanning logic first
-     pattern = os.path.join(cp_dir, prefix + "????????")
-     cp_list = glob.glob(pattern)
-
-     if len(cp_list) > 0:
-         last_checkpoint_path = sorted(cp_list)[-1]
-         print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
-         return last_checkpoint_path
-
-     # If no pattern-based checkpoints are found, check for renamed file
-     if renamed_file:
-         renamed_path = os.path.join(cp_dir, renamed_file)
-         if os.path.isfile(renamed_path):
-             print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
-             return renamed_path
-
-     return None
-
-
- def save_audio(audio, path, sr):
-     # wav: torch with 1d shape
-     audio = audio * MAX_WAV_VALUE
-     audio = audio.cpu().numpy().astype("int16")
-     write(path, sr, audio)
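The naming convention that scan_checkpoint relies on is easy to miss. A self-contained illustration against a synthetic temporary directory; the g_ prefix and step numbers are hypothetical examples of the "prefix plus eight characters" pattern the glob matches:

```python
import glob
import os
import tempfile

# Checkpoints are expected as "<prefix><8 characters>" (e.g. g_00090000); the glob
# pattern "????????" matches exactly eight characters and the lexicographically
# largest path is treated as the most recent checkpoint.
ckpt_dir = tempfile.mkdtemp()
for step in ("00010000", "00090000"):
    open(os.path.join(ckpt_dir, f"g_{step}"), "w").close()

candidates = glob.glob(os.path.join(ckpt_dir, "g_" + "????????"))
latest = sorted(candidates)[-1] if candidates else None
print(latest)   # .../g_00090000
```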
indextts/__init__.py DELETED
File without changes
indextts/cli.py DELETED
@@ -1,65 +0,0 @@
- import os
- import sys
- import warnings
- # Suppress warnings from tensorflow and other libraries
- warnings.filterwarnings("ignore", category=UserWarning)
- warnings.filterwarnings("ignore", category=FutureWarning)
- def main():
-     import argparse
-     parser = argparse.ArgumentParser(description="IndexTTS Command Line")
-     parser.add_argument("text", type=str, help="Text to be synthesized")
-     parser.add_argument("-v", "--voice", type=str, required=True, help="Path to the audio prompt file (wav format)")
-     parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
-     parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
-     parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
-     parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
-     parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
-     parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps, xpu).")
-     args = parser.parse_args()
-     if len(args.text.strip()) == 0:
-         print("ERROR: Text is empty.")
-         parser.print_help()
-         sys.exit(1)
-     if not os.path.exists(args.voice):
-         print(f"Audio prompt file {args.voice} does not exist.")
-         parser.print_help()
-         sys.exit(1)
-     if not os.path.exists(args.config):
-         print(f"Config file {args.config} does not exist.")
-         parser.print_help()
-         sys.exit(1)
-
-     output_path = args.output_path
-     if os.path.exists(output_path):
-         if not args.force:
-             print(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.")
-             parser.print_help()
-             sys.exit(1)
-         else:
-             os.remove(output_path)
-
-     try:
-         import torch
-     except ImportError:
-         print("ERROR: PyTorch is not installed. Please install it first.")
-         sys.exit(1)
-
-     if args.device is None:
-         if torch.cuda.is_available():
-             args.device = "cuda:0"
-         elif hasattr(torch, "xpu") and torch.xpu.is_available():
-             args.device = "xpu"
-         elif hasattr(torch, "mps") and torch.mps.is_available():
-             args.device = "mps"
-         else:
-             args.device = "cpu"
-             args.fp16 = False  # Disable FP16 on CPU
-             print("WARNING: Running on CPU may be slow.")
-
-     # TODO: Add CLI support for IndexTTS2.
-     from indextts.infer import IndexTTS
-     tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, use_fp16=args.fp16, device=args.device)
-     tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)
-
- if __name__ == "__main__":
-     main()
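For comparison, the same synthesis could be driven programmatically with the constructor and infer() signatures exactly as the CLI above called them before this removal; the config path, model directory and prompt wav below are placeholders rather than files shipped with this repository:

```python
# Programmatic equivalent of the command-line flow above, as of the removed code.
from indextts.infer import IndexTTS

tts = IndexTTS(
    cfg_path="checkpoints/config.yaml",   # same defaults the argument parser used
    model_dir="checkpoints",
    use_fp16=False,
    device="cpu",
)
tts.infer(
    audio_prompt="voice_prompt.wav",      # reference speaker recording (placeholder)
    text="Hello from IndexTTS.",
    output_path="gen.wav",
)
```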
indextts/gpt/__init__.py DELETED
File without changes
indextts/gpt/conformer/__init__.py DELETED
File without changes
indextts/gpt/conformer/attention.py DELETED
@@ -1,312 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- # 2022 Xingchen Song ([email protected])
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- """Multi-Head Attention layer definition."""
18
-
19
- import math
20
- from typing import Tuple
21
-
22
- import torch
23
- from torch import nn
24
-
25
-
26
- class MultiHeadedAttention(nn.Module):
27
- """Multi-Head Attention layer.
28
-
29
- Args:
30
- n_head (int): The number of heads.
31
- n_feat (int): The number of features.
32
- dropout_rate (float): Dropout rate.
33
-
34
- """
35
- def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
36
- """Construct an MultiHeadedAttention object."""
37
- super().__init__()
38
- assert n_feat % n_head == 0
39
- # We assume d_v always equals d_k
40
- self.d_k = n_feat // n_head
41
- self.h = n_head
42
- self.linear_q = nn.Linear(n_feat, n_feat)
43
- self.linear_k = nn.Linear(n_feat, n_feat)
44
- self.linear_v = nn.Linear(n_feat, n_feat)
45
- self.linear_out = nn.Linear(n_feat, n_feat)
46
- self.dropout = nn.Dropout(p=dropout_rate)
47
-
48
- def forward_qkv(
49
- self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
50
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51
- """Transform query, key and value.
52
-
53
- Args:
54
- query (torch.Tensor): Query tensor (#batch, time1, size).
55
- key (torch.Tensor): Key tensor (#batch, time2, size).
56
- value (torch.Tensor): Value tensor (#batch, time2, size).
57
-
58
- Returns:
59
- torch.Tensor: Transformed query tensor, size
60
- (#batch, n_head, time1, d_k).
61
- torch.Tensor: Transformed key tensor, size
62
- (#batch, n_head, time2, d_k).
63
- torch.Tensor: Transformed value tensor, size
64
- (#batch, n_head, time2, d_k).
65
-
66
- """
67
- n_batch = query.size(0)
68
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
69
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
70
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
71
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
72
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
73
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
74
-
75
- return q, k, v
76
-
77
- def forward_attention(
78
- self, value: torch.Tensor, scores: torch.Tensor,
79
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
80
- ) -> torch.Tensor:
81
- """Compute attention context vector.
82
-
83
- Args:
84
- value (torch.Tensor): Transformed value, size
85
- (#batch, n_head, time2, d_k).
86
- scores (torch.Tensor): Attention score, size
87
- (#batch, n_head, time1, time2).
88
- mask (torch.Tensor): Mask, size (#batch, 1, time2) or
89
- (#batch, time1, time2), (0, 0, 0) means fake mask.
90
-
91
- Returns:
92
- torch.Tensor: Transformed value (#batch, time1, d_model)
93
- weighted by the attention score (#batch, time1, time2).
94
-
95
- """
96
- n_batch = value.size(0)
97
- # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
98
- # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
99
- # 1st chunk to ease the onnx export.]
100
- # 2. pytorch training
101
- if mask.size(2) > 0 : # time2 > 0
102
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
103
- # For last chunk, time2 might be larger than scores.size(-1)
104
- mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
105
- scores = scores.masked_fill(mask, -float('inf'))
106
- attn = torch.softmax(scores, dim=-1).masked_fill(
107
- mask, 0.0) # (batch, head, time1, time2)
108
- # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
109
- # 1. onnx(16/-1, -1/-1, 16/0)
110
- # 2. jit (16/-1, -1/-1, 16/0, 16/4)
111
- else:
112
- attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
113
-
114
- p_attn = self.dropout(attn)
115
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
116
- x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
117
- self.h * self.d_k)
118
- ) # (batch, time1, d_model)
119
-
120
- return self.linear_out(x) # (batch, time1, d_model)
121
-
122
- def forward(self, query: torch.Tensor, key: torch.Tensor,
123
- value: torch.Tensor,
124
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
125
- pos_emb: torch.Tensor = torch.empty(0),
126
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
127
- ) -> Tuple[torch.Tensor, torch.Tensor]:
128
- """Compute scaled dot product attention.
129
-
130
- Args:
131
- query (torch.Tensor): Query tensor (#batch, time1, size).
132
- key (torch.Tensor): Key tensor (#batch, time2, size).
133
- value (torch.Tensor): Value tensor (#batch, time2, size).
134
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
135
- (#batch, time1, time2).
136
- 1.When applying cross attention between decoder and encoder,
137
- the batch padding mask for input is in (#batch, 1, T) shape.
138
- 2.When applying self attention of encoder,
139
- the mask is in (#batch, T, T) shape.
140
- 3.When applying self attention of decoder,
141
- the mask is in (#batch, L, L) shape.
142
- 4.If the different position in decoder see different block
143
- of the encoder, such as Mocha, the passed in mask could be
144
- in (#batch, L, T) shape. But there is no such case in current
145
- Wenet.
146
- cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
147
- where `cache_t == chunk_size * num_decoding_left_chunks`
148
- and `head * d_k == size`
149
-
150
-
151
- Returns:
152
- torch.Tensor: Output tensor (#batch, time1, d_model).
153
- torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
154
- where `cache_t == chunk_size * num_decoding_left_chunks`
155
- and `head * d_k == size`
156
-
157
- """
158
- q, k, v = self.forward_qkv(query, key, value)
159
-
160
- # NOTE(xcsong):
161
- # when export onnx model, for 1st chunk, we feed
162
- # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
163
- # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
164
- # In all modes, `if cache.size(0) > 0` will alwayse be `True`
165
- # and we will always do splitting and
166
- # concatnation(this will simplify onnx export). Note that
167
- # it's OK to concat & split zero-shaped tensors(see code below).
168
- # when export jit model, for 1st chunk, we always feed
169
- # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
170
- # >>> a = torch.ones((1, 2, 0, 4))
171
- # >>> b = torch.ones((1, 2, 3, 4))
172
- # >>> c = torch.cat((a, b), dim=2)
173
- # >>> torch.equal(b, c) # True
174
- # >>> d = torch.split(a, 2, dim=-1)
175
- # >>> torch.equal(d[0], d[1]) # True
176
- if cache.size(0) > 0:
177
- key_cache, value_cache = torch.split(
178
- cache, cache.size(-1) // 2, dim=-1)
179
- k = torch.cat([key_cache, k], dim=2)
180
- v = torch.cat([value_cache, v], dim=2)
181
- # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
182
- # non-trivial to calculate `next_cache_start` here.
183
- new_cache = torch.cat((k, v), dim=-1)
184
-
185
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
186
- return self.forward_attention(v, scores, mask), new_cache
187
-
188
-
189
- class RelPositionMultiHeadedAttention(MultiHeadedAttention):
190
- """Multi-Head Attention layer with relative position encoding.
191
- Paper: https://arxiv.org/abs/1901.02860
192
- Args:
193
- n_head (int): The number of heads.
194
- n_feat (int): The number of features.
195
- dropout_rate (float): Dropout rate.
196
- """
197
- def __init__(self, n_head, n_feat, dropout_rate):
198
- """Construct an RelPositionMultiHeadedAttention object."""
199
- super().__init__(n_head, n_feat, dropout_rate)
200
- # linear transformation for positional encoding
201
- self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
202
- # these two learnable bias are used in matrix c and matrix d
203
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
204
- self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
205
- self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
206
- torch.nn.init.xavier_uniform_(self.pos_bias_u)
207
- torch.nn.init.xavier_uniform_(self.pos_bias_v)
208
-
209
- def rel_shift(self, x, zero_triu: bool = False):
210
- """Compute relative positinal encoding.
211
- Args:
212
- x (torch.Tensor): Input tensor (batch, time, size).
213
- zero_triu (bool): If true, return the lower triangular part of
214
- the matrix.
215
- Returns:
216
- torch.Tensor: Output tensor.
217
- """
218
-
219
- zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
220
- device=x.device,
221
- dtype=x.dtype)
222
- x_padded = torch.cat([zero_pad, x], dim=-1)
223
-
224
- x_padded = x_padded.view(x.size()[0],
225
- x.size()[1],
226
- x.size(3) + 1, x.size(2))
227
- x = x_padded[:, :, 1:].view_as(x)
228
-
229
- if zero_triu:
230
- ones = torch.ones((x.size(2), x.size(3)))
231
- x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
232
-
233
- return x
234
-
235
- def forward(self, query: torch.Tensor,
236
- key: torch.Tensor, value: torch.Tensor,
237
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
238
- pos_emb: torch.Tensor = torch.empty(0),
239
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
240
- ) -> Tuple[torch.Tensor, torch.Tensor]:
241
- """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
242
- Args:
243
- query (torch.Tensor): Query tensor (#batch, time1, size).
244
- key (torch.Tensor): Key tensor (#batch, time2, size).
245
- value (torch.Tensor): Value tensor (#batch, time2, size).
246
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
247
- (#batch, time1, time2), (0, 0, 0) means fake mask.
248
- pos_emb (torch.Tensor): Positional embedding tensor
249
- (#batch, time2, size).
250
- cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
251
- where `cache_t == chunk_size * num_decoding_left_chunks`
252
- and `head * d_k == size`
253
- Returns:
254
- torch.Tensor: Output tensor (#batch, time1, d_model).
255
- torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
256
- where `cache_t == chunk_size * num_decoding_left_chunks`
257
- and `head * d_k == size`
258
- """
259
- q, k, v = self.forward_qkv(query, key, value)
260
- q = q.transpose(1, 2) # (batch, time1, head, d_k)
261
-
262
- # NOTE(xcsong):
263
- # when export onnx model, for 1st chunk, we feed
264
- # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
265
- # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
266
- # In all modes, `if cache.size(0) > 0` will alwayse be `True`
267
- # and we will always do splitting and
268
- # concatnation(this will simplify onnx export). Note that
269
- # it's OK to concat & split zero-shaped tensors(see code below).
270
- # when export jit model, for 1st chunk, we always feed
271
- # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
272
- # >>> a = torch.ones((1, 2, 0, 4))
273
- # >>> b = torch.ones((1, 2, 3, 4))
274
- # >>> c = torch.cat((a, b), dim=2)
275
- # >>> torch.equal(b, c) # True
276
- # >>> d = torch.split(a, 2, dim=-1)
277
- # >>> torch.equal(d[0], d[1]) # True
278
- if cache.size(0) > 0:
279
- key_cache, value_cache = torch.split(
280
- cache, cache.size(-1) // 2, dim=-1)
281
- k = torch.cat([key_cache, k], dim=2)
282
- v = torch.cat([value_cache, v], dim=2)
283
- # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
284
- # non-trivial to calculate `next_cache_start` here.
285
- new_cache = torch.cat((k, v), dim=-1)
286
-
287
- n_batch_pos = pos_emb.size(0)
288
- p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
289
- p = p.transpose(1, 2) # (batch, head, time1, d_k)
290
-
291
- # (batch, head, time1, d_k)
292
- q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
293
- # (batch, head, time1, d_k)
294
- q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
295
-
296
- # compute attention score
297
- # first compute matrix a and matrix c
298
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
299
- # (batch, head, time1, time2)
300
- matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
301
-
302
- # compute matrix b and matrix d
303
- # (batch, head, time1, time2)
304
- matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
305
- # Remove rel_shift since it is useless in speech recognition,
306
- # and it requires special attention for streaming.
307
- # matrix_bd = self.rel_shift(matrix_bd)
308
-
309
- scores = (matrix_ac + matrix_bd) / math.sqrt(
310
- self.d_k) # (batch, head, time1, time2)
311
-
312
- return self.forward_attention(v, scores, mask), new_cache
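A shape-level sketch of the score computation in RelPositionMultiHeadedAttention.forward above. Random tensors stand in for the projected queries, keys and positional embeddings, the learned biases are broadcast directly instead of going through the transpose round-trip used in the module, and rel_shift is skipped just as it is there:

```python
import math
import torch

batch, head, time1, time2, d_k = 2, 4, 6, 6, 16
q = torch.rand(batch, head, time1, d_k)            # forward_qkv output
k = torch.rand(batch, head, time2, d_k)
p = torch.rand(batch, head, time2, d_k)            # linear_pos(pos_emb), split per head
pos_bias_u = torch.zeros(head, d_k)                # learned biases (zero-initialised here)
pos_bias_v = torch.zeros(head, d_k)

# Matrices (a + c) use the content keys, matrices (b + d) use the positional embeddings.
matrix_ac = torch.matmul(q + pos_bias_u[None, :, None, :], k.transpose(-2, -1))
matrix_bd = torch.matmul(q + pos_bias_v[None, :, None, :], p.transpose(-2, -1))
scores = (matrix_ac + matrix_bd) / math.sqrt(d_k)  # (batch, head, time1, time2)
print(torch.softmax(scores, dim=-1).shape)         # torch.Size([2, 4, 6, 6])
```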
indextts/gpt/conformer/embedding.py DELETED
@@ -1,163 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # Modified from ESPnet(https://github.com/espnet/espnet)
15
-
16
- """Positonal Encoding Module."""
17
-
18
- import math
19
- from typing import Tuple, Union
20
-
21
- import torch
22
- import torch.nn.functional as F
23
-
24
-
25
- class PositionalEncoding(torch.nn.Module):
26
- """Positional encoding.
27
-
28
- :param int d_model: embedding dim
29
- :param float dropout_rate: dropout rate
30
- :param int max_len: maximum input length
31
-
32
- PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
33
- PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
34
- """
35
- def __init__(self,
36
- d_model: int,
37
- dropout_rate: float,
38
- max_len: int = 5000,
39
- reverse: bool = False):
40
- """Construct an PositionalEncoding object."""
41
- super().__init__()
42
- self.d_model = d_model
43
- self.xscale = math.sqrt(self.d_model)
44
- self.dropout = torch.nn.Dropout(p=dropout_rate)
45
- self.max_len = max_len
46
-
47
- pe = torch.zeros(self.max_len, self.d_model)
48
- position = torch.arange(0, self.max_len).unsqueeze(1)
49
- div_term = torch.exp(
50
- torch.arange(0, self.d_model, 2) *
51
- -(math.log(10000.0) / self.d_model))
52
- pe[:, 0::2] = torch.sin(position * div_term)
53
- pe[:, 1::2] = torch.cos(position * div_term)
54
- pe = pe.unsqueeze(0)
55
- self.register_buffer('pe', pe)
56
-
57
- def forward(self,
58
- x: torch.Tensor,
59
- offset: Union[int, torch.Tensor] = 0) \
60
- -> Tuple[torch.Tensor, torch.Tensor]:
61
- """Add positional encoding.
62
-
63
- Args:
64
- x (torch.Tensor): Input. Its shape is (batch, time, ...)
65
- offset (int, torch.tensor): position offset
66
-
67
- Returns:
68
- torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
69
- torch.Tensor: for compatibility to RelPositionalEncoding
70
- """
71
-
72
- self.pe = self.pe.to(x.device)
73
- pos_emb = self.position_encoding(offset, x.size(1), False)
74
- x = x * self.xscale + pos_emb
75
- return self.dropout(x), self.dropout(pos_emb)
76
-
77
- def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
78
- apply_dropout: bool = True) -> torch.Tensor:
79
- """ For getting encoding in a streaming fashion
80
-
81
- Attention!!!!!
82
- we apply dropout only once at the whole utterance level in a none
83
- streaming way, but will call this function several times with
84
- increasing input size in a streaming scenario, so the dropout will
85
- be applied several times.
86
-
87
- Args:
88
- offset (int or torch.tensor): start offset
89
- size (int): required size of position encoding
90
-
91
- Returns:
92
- torch.Tensor: Corresponding encoding
93
- """
94
- # How to subscript a Union type:
95
- # https://github.com/pytorch/pytorch/issues/69434
96
- if isinstance(offset, int):
97
- assert offset + size < self.max_len
98
- pos_emb = self.pe[:, offset:offset + size]
99
- elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
100
- assert offset + size < self.max_len
101
- pos_emb = self.pe[:, offset:offset + size]
102
- else: # for batched streaming decoding on GPU
103
- assert torch.max(offset) + size < self.max_len
104
- index = offset.unsqueeze(1) + \
105
- torch.arange(0, size).to(offset.device) # B X T
106
- flag = index > 0
107
- # remove negative offset
108
- index = index * flag
109
- pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
110
-
111
- if apply_dropout:
112
- pos_emb = self.dropout(pos_emb)
113
- return pos_emb
114
-
115
- class RelPositionalEncoding(PositionalEncoding):
116
- """Relative positional encoding module.
117
- See : Appendix B in https://arxiv.org/abs/1901.02860
118
- Args:
119
- d_model (int): Embedding dimension.
120
- dropout_rate (float): Dropout rate.
121
- max_len (int): Maximum input length.
122
- """
123
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
124
- """Initialize class."""
125
- super().__init__(d_model, dropout_rate, max_len, reverse=True)
126
-
127
- def forward(self,
128
- x: torch.Tensor,
129
- offset: Union[int, torch.Tensor] = 0) \
130
- -> Tuple[torch.Tensor, torch.Tensor]:
131
- """Compute positional encoding.
132
- Args:
133
- x (torch.Tensor): Input tensor (batch, time, `*`).
134
- Returns:
135
- torch.Tensor: Encoded tensor (batch, time, `*`).
136
- torch.Tensor: Positional embedding tensor (1, time, `*`).
137
- """
138
- self.pe = self.pe.to(x.device)
139
- x = x * self.xscale
140
- pos_emb = self.position_encoding(offset, x.size(1), False)
141
- return self.dropout(x), self.dropout(pos_emb)
142
-
143
-
144
- class NoPositionalEncoding(torch.nn.Module):
145
- """ No position encoding
146
- """
147
- def __init__(self, d_model: int, dropout_rate: float):
148
- super().__init__()
149
- self.d_model = d_model
150
- self.dropout = torch.nn.Dropout(p=dropout_rate)
151
-
152
- def forward(self,
153
- x: torch.Tensor,
154
- offset: Union[int, torch.Tensor] = 0) \
155
- -> Tuple[torch.Tensor, torch.Tensor]:
156
- """ Just return zero vector for interface compatibility
157
- """
158
- pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
159
- return self.dropout(x), pos_emb
160
-
161
- def position_encoding(
162
- self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
163
- return torch.zeros(1, size, self.d_model)
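The sinusoidal table that PositionalEncoding.__init__ above registers as a buffer can be rebuilt on its own; the snippet mirrors that construction with small example sizes:

```python
import math
import torch

# PE(pos, 2i) = sin(pos / 10000^(2i / d_model)),  PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
d_model, max_len = 8, 100
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
print(pe.shape)   # torch.Size([100, 8]); row p is the encoding added to the frame at position p
```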
indextts/gpt/conformer/subsampling.py DELETED
@@ -1,348 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # Modified from ESPnet(https://github.com/espnet/espnet)
15
-
16
-
17
- """Subsampling layer definition."""
18
-
19
- from typing import Tuple, Union
20
-
21
- import torch
22
-
23
-
24
- class BaseSubsampling(torch.nn.Module):
25
- def __init__(self):
26
- super().__init__()
27
- self.right_context = 0
28
- self.subsampling_rate = 1
29
-
30
- def position_encoding(self, offset: Union[int, torch.Tensor],
31
- size: int) -> torch.Tensor:
32
- return self.pos_enc.position_encoding(offset, size)
33
-
34
-
35
- class LinearNoSubsampling(BaseSubsampling):
36
- """Linear transform the input without subsampling
37
-
38
- Args:
39
- idim (int): Input dimension.
40
- odim (int): Output dimension.
41
- dropout_rate (float): Dropout rate.
42
-
43
- """
44
- def __init__(self, idim: int, odim: int, dropout_rate: float,
45
- pos_enc_class: torch.nn.Module):
46
- """Construct an linear object."""
47
- super().__init__()
48
- self.out = torch.nn.Sequential(
49
- torch.nn.Linear(idim, odim),
50
- torch.nn.LayerNorm(odim, eps=1e-5),
51
- torch.nn.Dropout(dropout_rate),
52
- )
53
- self.pos_enc = pos_enc_class
54
- self.right_context = 0
55
- self.subsampling_rate = 1
56
-
57
- def forward(
58
- self,
59
- x: torch.Tensor,
60
- x_mask: torch.Tensor,
61
- offset: Union[int, torch.Tensor] = 0
62
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
- """Input x.
64
-
65
- Args:
66
- x (torch.Tensor): Input tensor (#batch, time, idim).
67
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
68
-
69
- Returns:
70
- torch.Tensor: linear input tensor (#batch, time', odim),
71
- where time' = time .
72
- torch.Tensor: linear input mask (#batch, 1, time'),
73
- where time' = time .
74
-
75
- """
76
- x = self.out(x)
77
- x, pos_emb = self.pos_enc(x, offset)
78
- return x, pos_emb, x_mask
79
-
80
-
81
- class Conv2dSubsampling3(BaseSubsampling):
82
- """Convolutional 2D subsampling (to 1/3 length).
83
-
84
- Args:
85
- idim (int): Input dimension.
86
- odim (int): Output dimension.
87
- dropout_rate (float): Dropout rate.
88
-
89
- """
90
- def __init__(self, idim: int, odim: int, dropout_rate: float,
91
- pos_enc_class: torch.nn.Module):
92
- """Construct an Conv2dSubsampling3 object."""
93
- super().__init__()
94
- self.conv = torch.nn.Sequential(
95
- torch.nn.Conv2d(1, odim, 5, 3),
96
- torch.nn.ReLU()
97
- )
98
- self.out = torch.nn.Sequential(
99
- torch.nn.Linear(odim * ((idim - 2) // 3), odim))
100
- self.pos_enc = pos_enc_class
101
- # The right context for every conv layer is computed by:
102
- # (kernel_size - 1) * frame_rate_of_this_layer
103
- self.subsampling_rate = 3
104
- # 4 = (5 - 1) * 1
105
- self.right_context = 4
106
-
107
- def forward(
108
- self,
109
- x: torch.Tensor,
110
- x_mask: torch.Tensor,
111
- offset: Union[int, torch.Tensor] = 0
112
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
113
- """Subsample x.
114
-
115
- Args:
116
- x (torch.Tensor): Input tensor (#batch, time, idim).
117
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
118
-
119
- Returns:
120
- torch.Tensor: Subsampled tensor (#batch, time', odim),
121
- where time' = time // 3.
122
- torch.Tensor: Subsampled mask (#batch, 1, time'),
123
- where time' = time // 3.
124
- torch.Tensor: positional encoding
125
-
126
- """
127
- x = x.unsqueeze(1) # (b, c=1, t, f)
128
- x = self.conv(x)
129
- b, c, t, f = x.size()
130
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
131
- x, pos_emb = self.pos_enc(x, offset)
132
- return x, pos_emb, x_mask[:, :, :-2:3]
133
-
134
-
135
- class Conv2dSubsampling2(BaseSubsampling):
136
- """Convolutional 2D subsampling (to 1/2 length).
137
-
138
- Args:
139
- idim (int): Input dimension.
140
- odim (int): Output dimension.
141
- dropout_rate (float): Dropout rate.
142
-
143
- """
144
- def __init__(self, idim: int, odim: int, dropout_rate: float,
145
- pos_enc_class: torch.nn.Module):
146
- """Construct an Conv2dSubsampling4 object."""
147
- super().__init__()
148
- self.conv = torch.nn.Sequential(
149
- torch.nn.Conv2d(1, odim, 3, 2),
150
- torch.nn.ReLU(),
151
- )
152
- self.out = torch.nn.Sequential(
153
- torch.nn.Linear(odim * ((idim - 1) // 2), odim))
154
- self.pos_enc = pos_enc_class
155
- # The right context for every conv layer is computed by:
156
- # (kernel_size - 1) * frame_rate_of_this_layer
157
- self.subsampling_rate = 2
158
- # 2 = (3 - 1) * 1
159
- self.right_context = 2
160
-
161
- def forward(
162
- self,
163
- x: torch.Tensor,
164
- x_mask: torch.Tensor,
165
- offset: Union[int, torch.Tensor] = 0
166
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
167
- """Subsample x.
168
-
169
- Args:
170
- x (torch.Tensor): Input tensor (#batch, time, idim).
171
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
172
-
173
- Returns:
174
- torch.Tensor: Subsampled tensor (#batch, time', odim),
175
- where time' = time // 2.
176
- torch.Tensor: Subsampled mask (#batch, 1, time'),
177
- where time' = time // 2.
178
- torch.Tensor: positional encoding
179
-
180
- """
181
- x = x.unsqueeze(1) # (b, c=1, t, f)
182
- x = self.conv(x)
183
- b, c, t, f = x.size()
184
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
185
- x, pos_emb = self.pos_enc(x, offset)
186
- return x, pos_emb, x_mask[:, :, 2::2]
187
-
188
-
189
- class Conv2dSubsampling4(BaseSubsampling):
190
- """Convolutional 2D subsampling (to 1/4 length).
191
-
192
- Args:
193
- idim (int): Input dimension.
194
- odim (int): Output dimension.
195
- dropout_rate (float): Dropout rate.
196
-
197
- """
198
- def __init__(self, idim: int, odim: int, dropout_rate: float,
199
- pos_enc_class: torch.nn.Module):
200
- """Construct an Conv2dSubsampling4 object."""
201
- super().__init__()
202
- self.conv = torch.nn.Sequential(
203
- torch.nn.Conv2d(1, odim, 3, 2),
204
- torch.nn.ReLU(),
205
- torch.nn.Conv2d(odim, odim, 3, 2),
206
- torch.nn.ReLU(),
207
- )
208
- self.out = torch.nn.Sequential(
209
- torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
210
- self.pos_enc = pos_enc_class
211
- # The right context for every conv layer is computed by:
212
- # (kernel_size - 1) * frame_rate_of_this_layer
213
- self.subsampling_rate = 4
214
- # 6 = (3 - 1) * 1 + (3 - 1) * 2
215
- self.right_context = 6
216
-
217
- def forward(
218
- self,
219
- x: torch.Tensor,
220
- x_mask: torch.Tensor,
221
- offset: Union[int, torch.Tensor] = 0
222
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
223
- """Subsample x.
224
-
225
- Args:
226
- x (torch.Tensor): Input tensor (#batch, time, idim).
227
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
228
-
229
- Returns:
230
- torch.Tensor: Subsampled tensor (#batch, time', odim),
231
- where time' = time // 4.
232
- torch.Tensor: Subsampled mask (#batch, 1, time'),
233
- where time' = time // 4.
234
- torch.Tensor: positional encoding
235
-
236
- """
237
- x = x.unsqueeze(1) # (b, c=1, t, f)
238
- x = self.conv(x)
239
- b, c, t, f = x.size()
240
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
241
- x, pos_emb = self.pos_enc(x, offset)
242
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
243
-
244
-
245
- class Conv2dSubsampling6(BaseSubsampling):
246
- """Convolutional 2D subsampling (to 1/6 length).
247
- Args:
248
- idim (int): Input dimension.
249
- odim (int): Output dimension.
250
- dropout_rate (float): Dropout rate.
251
- pos_enc (torch.nn.Module): Custom position encoding layer.
252
- """
253
- def __init__(self, idim: int, odim: int, dropout_rate: float,
254
- pos_enc_class: torch.nn.Module):
255
- """Construct an Conv2dSubsampling6 object."""
256
- super().__init__()
257
- self.conv = torch.nn.Sequential(
258
- torch.nn.Conv2d(1, odim, 3, 2),
259
- torch.nn.ReLU(),
260
- torch.nn.Conv2d(odim, odim, 5, 3),
261
- torch.nn.ReLU(),
262
- )
263
- self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
264
- odim)
265
- self.pos_enc = pos_enc_class
266
- # 10 = (3 - 1) * 1 + (5 - 1) * 2
267
- self.subsampling_rate = 6
268
- self.right_context = 10
269
-
270
- def forward(
271
- self,
272
- x: torch.Tensor,
273
- x_mask: torch.Tensor,
274
- offset: Union[int, torch.Tensor] = 0
275
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
276
- """Subsample x.
277
- Args:
278
- x (torch.Tensor): Input tensor (#batch, time, idim).
279
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
280
-
281
- Returns:
282
- torch.Tensor: Subsampled tensor (#batch, time', odim),
283
- where time' = time // 6.
284
- torch.Tensor: Subsampled mask (#batch, 1, time'),
285
- where time' = time // 6.
286
- torch.Tensor: positional encoding
287
- """
288
- x = x.unsqueeze(1) # (b, c, t, f)
289
- x = self.conv(x)
290
- b, c, t, f = x.size()
291
- x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
292
- x, pos_emb = self.pos_enc(x, offset)
293
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
294
-
295
-
296
- class Conv2dSubsampling8(BaseSubsampling):
297
- """Convolutional 2D subsampling (to 1/8 length).
298
-
299
- Args:
300
- idim (int): Input dimension.
301
- odim (int): Output dimension.
302
- dropout_rate (float): Dropout rate.
303
-
304
- """
305
- def __init__(self, idim: int, odim: int, dropout_rate: float,
306
- pos_enc_class: torch.nn.Module):
307
- """Construct an Conv2dSubsampling8 object."""
308
- super().__init__()
309
- self.conv = torch.nn.Sequential(
310
- torch.nn.Conv2d(1, odim, 3, 2),
311
- torch.nn.ReLU(),
312
- torch.nn.Conv2d(odim, odim, 3, 2),
313
- torch.nn.ReLU(),
314
- torch.nn.Conv2d(odim, odim, 3, 2),
315
- torch.nn.ReLU(),
316
- )
317
- self.linear = torch.nn.Linear(
318
- odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
319
- self.pos_enc = pos_enc_class
320
- self.subsampling_rate = 8
321
- # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
322
- self.right_context = 14
323
-
324
- def forward(
325
- self,
326
- x: torch.Tensor,
327
- x_mask: torch.Tensor,
328
- offset: Union[int, torch.Tensor] = 0
329
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
330
- """Subsample x.
331
-
332
- Args:
333
- x (torch.Tensor): Input tensor (#batch, time, idim).
334
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
335
-
336
- Returns:
337
- torch.Tensor: Subsampled tensor (#batch, time', odim),
338
- where time' = time // 8.
339
- torch.Tensor: Subsampled mask (#batch, 1, time'),
340
- where time' = time // 8.
341
- torch.Tensor: positional encoding
342
- """
343
- x = x.unsqueeze(1) # (b, c, t, f)
344
- x = self.conv(x)
345
- b, c, t, f = x.size()
346
- x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
347
- x, pos_emb = self.pos_enc(x, offset)
348
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
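Editor's note: the hunk above removes the whole subsampling module. As a reference for readers of this diff, here is a minimal sketch (not part of the commit, and assuming the deleted class were still importable) that illustrates the length arithmetic of Conv2dSubsampling4: two stride-2 convolutions cut the time axis roughly to a quarter, and the padding mask is thinned with the same [:, :, 2::2] slicing. DummyPosEnc is a hypothetical stand-in for the real positional-encoding classes that lived in indextts/gpt/conformer/embedding.py.

import torch
from indextts.gpt.conformer.subsampling import Conv2dSubsampling4  # removed by this commit

class DummyPosEnc(torch.nn.Module):
    # Hypothetical stand-in: returns the input unchanged plus a zero positional embedding.
    def forward(self, x, offset=0):
        return x, torch.zeros(1, x.size(1), x.size(2))

sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.0, pos_enc_class=DummyPosEnc())
x = torch.randn(2, 100, 80)                      # (batch, time, idim)
mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (batch, 1, time)
y, pos_emb, y_mask = sub(x, mask)
print(y.shape, y_mask.shape)                     # torch.Size([2, 24, 256]) torch.Size([2, 1, 24])
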
indextts/gpt/conformer_encoder.py DELETED
@@ -1,520 +0,0 @@
1
-
2
- from typing import Optional, Tuple
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from indextts.gpt.conformer.attention import (MultiHeadedAttention,
8
- RelPositionMultiHeadedAttention)
9
- from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
10
- PositionalEncoding,
11
- RelPositionalEncoding)
12
- from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
13
- Conv2dSubsampling4,
14
- Conv2dSubsampling6,
15
- Conv2dSubsampling8,
16
- LinearNoSubsampling)
17
- from indextts.utils.common import make_pad_mask
18
-
19
-
20
- class PositionwiseFeedForward(torch.nn.Module):
21
- """Positionwise feed forward layer.
22
-
23
- FeedForward is applied at each position of the sequence.
24
- The output dim is the same as the input dim.
25
-
26
- Args:
27
- idim (int): Input dimension.
28
- hidden_units (int): The number of hidden units.
29
- dropout_rate (float): Dropout rate.
30
- activation (torch.nn.Module): Activation function
31
- """
32
-
33
- def __init__(self,
34
- idim: int,
35
- hidden_units: int,
36
- dropout_rate: float,
37
- activation: torch.nn.Module = torch.nn.ReLU()):
38
- """Construct a PositionwiseFeedForward object."""
39
- super(PositionwiseFeedForward, self).__init__()
40
- self.w_1 = torch.nn.Linear(idim, hidden_units)
41
- self.activation = activation
42
- self.dropout = torch.nn.Dropout(dropout_rate)
43
- self.w_2 = torch.nn.Linear(hidden_units, idim)
44
-
45
- def forward(self, xs: torch.Tensor) -> torch.Tensor:
46
- """Forward function.
47
-
48
- Args:
49
- xs: input tensor (B, L, D)
50
- Returns:
51
- output tensor, (B, L, D)
52
- """
53
- return self.w_2(self.dropout(self.activation(self.w_1(xs))))
54
-
55
-
56
- class ConvolutionModule(nn.Module):
57
- """ConvolutionModule in Conformer model."""
58
-
59
- def __init__(self,
60
- channels: int,
61
- kernel_size: int = 15,
62
- activation: nn.Module = nn.ReLU(),
63
- bias: bool = True):
64
- """Construct an ConvolutionModule object.
65
- Args:
66
- channels (int): The number of channels of conv layers.
67
- kernel_size (int): Kernel size of conv layers.
68
- causal (int): Whether to use causal convolution or not.
69
- """
70
- super().__init__()
71
-
72
- self.pointwise_conv1 = nn.Conv1d(
73
- channels,
74
- 2 * channels,
75
- kernel_size=1,
76
- stride=1,
77
- padding=0,
78
- bias=bias,
79
- )
80
- # self.lorder is used to distinguish if it's a causal convolution,
81
- # if self.lorder > 0: it's a causal convolution, the input will be
82
- # padded with self.lorder frames on the left in forward.
83
- # else: it's a symmetrical convolution
84
- # kernel_size should be an odd number for non-causal convolution
85
- assert (kernel_size - 1) % 2 == 0
86
- padding = (kernel_size - 1) // 2
87
- self.lorder = 0
88
-
89
- self.depthwise_conv = nn.Conv1d(
90
- channels,
91
- channels,
92
- kernel_size,
93
- stride=1,
94
- padding=padding,
95
- groups=channels,
96
- bias=bias,
97
- )
98
-
99
- self.use_layer_norm = True
100
- self.norm = nn.LayerNorm(channels)
101
-
102
- self.pointwise_conv2 = nn.Conv1d(
103
- channels,
104
- channels,
105
- kernel_size=1,
106
- stride=1,
107
- padding=0,
108
- bias=bias,
109
- )
110
- self.activation = activation
111
-
112
- def forward(
113
- self,
114
- x: torch.Tensor,
115
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
116
- cache: torch.Tensor = torch.zeros((0, 0, 0)),
117
- ) -> Tuple[torch.Tensor, torch.Tensor]:
118
- """Compute convolution module.
119
- Args:
120
- x (torch.Tensor): Input tensor (#batch, time, channels).
121
- mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
122
- (0, 0, 0) means fake mask.
123
- cache (torch.Tensor): left context cache, it is only
124
- used in causal convolution (#batch, channels, cache_t),
125
- (0, 0, 0) means fake cache.
126
- Returns:
127
- torch.Tensor: Output tensor (#batch, time, channels).
128
- """
129
- # exchange the temporal dimension and the feature dimension
130
- x = x.transpose(1, 2) # (#batch, channels, time)
131
-
132
- # mask batch padding
133
- if mask_pad.size(2) > 0: # time > 0
134
- x.masked_fill_(~mask_pad, 0.0)
135
-
136
- if self.lorder > 0:
137
- if cache.size(2) == 0: # cache_t == 0
138
- x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
139
- else:
140
- assert cache.size(0) == x.size(0) # equal batch
141
- assert cache.size(1) == x.size(1) # equal channel
142
- x = torch.cat((cache, x), dim=2)
143
- assert (x.size(2) > self.lorder)
144
- new_cache = x[:, :, -self.lorder:]
145
- else:
146
- # It's better we just return None if no cache is required,
147
- # However, for JIT export, here we just fake one tensor instead of
148
- # None.
149
- new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
150
-
151
- # GLU mechanism
152
- x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
153
- x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
154
-
155
- # 1D Depthwise Conv
156
- x = self.depthwise_conv(x)
157
- if self.use_layer_norm:
158
- x = x.transpose(1, 2)
159
- x = self.activation(self.norm(x))
160
- if self.use_layer_norm:
161
- x = x.transpose(1, 2)
162
- x = self.pointwise_conv2(x)
163
- # mask batch padding
164
- if mask_pad.size(2) > 0: # time > 0
165
- x.masked_fill_(~mask_pad, 0.0)
166
-
167
- return x.transpose(1, 2), new_cache
168
-
169
-
170
- class ConformerEncoderLayer(nn.Module):
171
- """Encoder layer module.
172
- Args:
173
- size (int): Input dimension.
174
- self_attn (torch.nn.Module): Self-attention module instance.
175
- `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
176
- instance can be used as the argument.
177
- feed_forward (torch.nn.Module): Feed-forward module instance.
178
- `PositionwiseFeedForward` instance can be used as the argument.
179
- feed_forward_macaron (torch.nn.Module): Additional feed-forward module
180
- instance.
181
- `PositionwiseFeedForward` instance can be used as the argument.
182
- conv_module (torch.nn.Module): Convolution module instance.
183
- `ConvolutionModule` instance can be used as the argument.
184
- dropout_rate (float): Dropout rate.
185
- normalize_before (bool):
186
- True: use layer_norm before each sub-block.
187
- False: use layer_norm after each sub-block.
188
- concat_after (bool): Whether to concat attention layer's input and
189
- output.
190
- True: x -> x + linear(concat(x, att(x)))
191
- False: x -> x + att(x)
192
- """
193
-
194
- def __init__(
195
- self,
196
- size: int,
197
- self_attn: torch.nn.Module,
198
- feed_forward: Optional[nn.Module] = None,
199
- feed_forward_macaron: Optional[nn.Module] = None,
200
- conv_module: Optional[nn.Module] = None,
201
- dropout_rate: float = 0.1,
202
- normalize_before: bool = True,
203
- concat_after: bool = False,
204
- ):
205
- """Construct an EncoderLayer object."""
206
- super().__init__()
207
- self.self_attn = self_attn
208
- self.feed_forward = feed_forward
209
- self.feed_forward_macaron = feed_forward_macaron
210
- self.conv_module = conv_module
211
- self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
212
- self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
213
- if feed_forward_macaron is not None:
214
- self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
215
- self.ff_scale = 0.5
216
- else:
217
- self.ff_scale = 1.0
218
- if self.conv_module is not None:
219
- self.norm_conv = nn.LayerNorm(size,
220
- eps=1e-5) # for the CNN module
221
- self.norm_final = nn.LayerNorm(
222
- size, eps=1e-5) # for the final output of the block
223
- self.dropout = nn.Dropout(dropout_rate)
224
- self.size = size
225
- self.normalize_before = normalize_before
226
- self.concat_after = concat_after
227
- if self.concat_after:
228
- self.concat_linear = nn.Linear(size + size, size)
229
- else:
230
- self.concat_linear = nn.Identity()
231
-
232
- def forward(
233
- self,
234
- x: torch.Tensor,
235
- mask: torch.Tensor,
236
- pos_emb: torch.Tensor,
237
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
238
- att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
239
- cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
240
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
241
- """Compute encoded features.
242
-
243
- Args:
244
- x (torch.Tensor): (#batch, time, size)
245
- mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
246
- (0, 0, 0) means fake mask.
247
- pos_emb (torch.Tensor): positional encoding, must not be None
248
- for ConformerEncoderLayer.
249
- mask_pad (torch.Tensor): batch padding mask used for conv module.
250
- (#batch, 1,time), (0, 0, 0) means fake mask.
251
- att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
252
- (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
253
- cnn_cache (torch.Tensor): Convolution cache in conformer layer
254
- (#batch=1, size, cache_t2)
255
- Returns:
256
- torch.Tensor: Output tensor (#batch, time, size).
257
- torch.Tensor: Mask tensor (#batch, time, time).
258
- torch.Tensor: att_cache tensor,
259
- (#batch=1, head, cache_t1 + time, d_k * 2).
260
- torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
261
- """
262
-
263
- # whether to use macaron style
264
- if self.feed_forward_macaron is not None:
265
- residual = x
266
- if self.normalize_before:
267
- x = self.norm_ff_macaron(x)
268
- x = residual + self.ff_scale * self.dropout(
269
- self.feed_forward_macaron(x))
270
- if not self.normalize_before:
271
- x = self.norm_ff_macaron(x)
272
-
273
- # multi-headed self-attention module
274
- residual = x
275
- if self.normalize_before:
276
- x = self.norm_mha(x)
277
-
278
- x_att, new_att_cache = self.self_attn(
279
- x, x, x, mask, pos_emb, att_cache)
280
- if self.concat_after:
281
- x_concat = torch.cat((x, x_att), dim=-1)
282
- x = residual + self.concat_linear(x_concat)
283
- else:
284
- x = residual + self.dropout(x_att)
285
- if not self.normalize_before:
286
- x = self.norm_mha(x)
287
-
288
- # convolution module
289
- # Fake new cnn cache here, and then change it in conv_module
290
- new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
291
- if self.conv_module is not None:
292
- residual = x
293
- if self.normalize_before:
294
- x = self.norm_conv(x)
295
- x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
296
- x = residual + self.dropout(x)
297
-
298
- if not self.normalize_before:
299
- x = self.norm_conv(x)
300
-
301
- # feed forward module
302
- residual = x
303
- if self.normalize_before:
304
- x = self.norm_ff(x)
305
-
306
- x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
307
- if not self.normalize_before:
308
- x = self.norm_ff(x)
309
-
310
- if self.conv_module is not None:
311
- x = self.norm_final(x)
312
-
313
- return x, mask, new_att_cache, new_cnn_cache
314
-
315
-
316
- class BaseEncoder(torch.nn.Module):
317
- def __init__(
318
- self,
319
- input_size: int,
320
- output_size: int = 256,
321
- attention_heads: int = 4,
322
- linear_units: int = 2048,
323
- num_blocks: int = 6,
324
- dropout_rate: float = 0.0,
325
- input_layer: str = "conv2d",
326
- pos_enc_layer_type: str = "abs_pos",
327
- normalize_before: bool = True,
328
- concat_after: bool = False,
329
- ):
330
- """
331
- Args:
332
- input_size (int): input dim
333
- output_size (int): dimension of attention
334
- attention_heads (int): the number of heads of multi head attention
335
- linear_units (int): the hidden units number of position-wise feed
336
- forward
337
- num_blocks (int): the number of decoder blocks
338
- dropout_rate (float): dropout rate
339
- attention_dropout_rate (float): dropout rate in attention
340
- positional_dropout_rate (float): dropout rate after adding
341
- positional encoding
342
- input_layer (str): input layer type.
343
- optional [linear, conv2d, conv2d6, conv2d8]
344
- pos_enc_layer_type (str): Encoder positional encoding layer type.
345
- optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
346
- normalize_before (bool):
347
- True: use layer_norm before each sub-block of a layer.
348
- False: use layer_norm after each sub-block of a layer.
349
- concat_after (bool): whether to concat attention layer's input
350
- and output.
351
- True: x -> x + linear(concat(x, att(x)))
352
- False: x -> x + att(x)
353
- static_chunk_size (int): chunk size for static chunk training and
354
- decoding
355
- use_dynamic_chunk (bool): whether to use dynamic chunk size for
356
- training or not. You can only use a fixed chunk (chunk_size > 0)
357
- or a dynamic chunk size (use_dynamic_chunk = True)
358
- global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
359
- use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
360
- dynamic chunk training
361
- """
362
- super().__init__()
363
- self._output_size = output_size
364
-
365
- if pos_enc_layer_type == "abs_pos":
366
- pos_enc_class = PositionalEncoding
367
- elif pos_enc_layer_type == "rel_pos":
368
- pos_enc_class = RelPositionalEncoding
369
- elif pos_enc_layer_type == "no_pos":
370
- pos_enc_class = NoPositionalEncoding
371
- else:
372
- raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
373
-
374
- if input_layer == "linear":
375
- subsampling_class = LinearNoSubsampling
376
- elif input_layer == "conv2d2":
377
- subsampling_class = Conv2dSubsampling2
378
- elif input_layer == "conv2d":
379
- subsampling_class = Conv2dSubsampling4
380
- elif input_layer == "conv2d6":
381
- subsampling_class = Conv2dSubsampling6
382
- elif input_layer == "conv2d8":
383
- subsampling_class = Conv2dSubsampling8
384
- else:
385
- raise ValueError("unknown input_layer: " + input_layer)
386
-
387
- self.embed = subsampling_class(
388
- input_size,
389
- output_size,
390
- dropout_rate,
391
- pos_enc_class(output_size, dropout_rate),
392
- )
393
-
394
- self.normalize_before = normalize_before
395
- self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
396
-
397
- def output_size(self) -> int:
398
- return self._output_size
399
-
400
- def forward(
401
- self,
402
- xs: torch.Tensor,
403
- xs_lens: torch.Tensor,
404
- ) -> Tuple[torch.Tensor, torch.Tensor]:
405
- """Embed positions in tensor.
406
-
407
- Args:
408
- xs: padded input tensor (B, T, D)
409
- xs_lens: input length (B)
410
- decoding_chunk_size: decoding chunk size for dynamic chunk
411
- 0: default for training, use random dynamic chunk.
412
- <0: for decoding, use full chunk.
413
- >0: for decoding, use fixed chunk size as set.
414
- num_decoding_left_chunks: number of left chunks, this is for decoding,
415
- the chunk size is decoding_chunk_size.
416
- >=0: use num_decoding_left_chunks
417
- <0: use all left chunks
418
- Returns:
419
- encoder output tensor xs, and subsampled masks
420
- xs: padded output tensor (B, T' ~= T/subsample_rate, D)
421
- masks: torch.Tensor batch padding mask after subsample
422
- (B, 1, T' ~= T/subsample_rate)
423
- """
424
- T = xs.size(1)
425
- masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
426
- xs, pos_emb, masks = self.embed(xs, masks)
427
- chunk_masks = masks
428
- mask_pad = masks # (B, 1, T/subsample_rate)
429
- for layer in self.encoders:
430
- xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
431
- if self.normalize_before:
432
- xs = self.after_norm(xs)
433
- # Here we assume the mask is not changed in encoder layers, so just
434
- # return the masks before encoder layers, and the masks will be used
435
- # for cross attention with decoder later
436
- return xs, masks
437
-
438
-
439
- class ConformerEncoder(BaseEncoder):
440
- """Conformer encoder module."""
441
-
442
- def __init__(
443
- self,
444
- input_size: int,
445
- output_size: int = 256,
446
- attention_heads: int = 4,
447
- linear_units: int = 2048,
448
- num_blocks: int = 6,
449
- dropout_rate: float = 0.0,
450
- input_layer: str = "conv2d",
451
- pos_enc_layer_type: str = "rel_pos",
452
- normalize_before: bool = True,
453
- concat_after: bool = False,
454
- macaron_style: bool = False,
455
- use_cnn_module: bool = True,
456
- cnn_module_kernel: int = 15,
457
- ):
458
- """Construct ConformerEncoder
459
-
460
- Args:
461
- input_size to use_dynamic_chunk, see in BaseEncoder
462
- positionwise_conv_kernel_size (int): Kernel size of positionwise
463
- conv1d layer.
464
- macaron_style (bool): Whether to use macaron style for
465
- positionwise layer.
466
- selfattention_layer_type (str): Encoder attention layer type,
467
- the parameter has no effect now, it's just for configure
468
- compatibility.
469
- activation_type (str): Encoder activation function type.
470
- use_cnn_module (bool): Whether to use convolution module.
471
- cnn_module_kernel (int): Kernel size of convolution module.
472
- causal (bool): whether to use causal convolution or not.
473
- """
474
-
475
- super().__init__(input_size, output_size, attention_heads,
476
- linear_units, num_blocks, dropout_rate,
477
- input_layer, pos_enc_layer_type, normalize_before,
478
- concat_after)
479
-
480
- activation = torch.nn.SiLU()
481
-
482
- # self-attention module definition
483
- if pos_enc_layer_type != "rel_pos":
484
- encoder_selfattn_layer = MultiHeadedAttention
485
- else:
486
- encoder_selfattn_layer = RelPositionMultiHeadedAttention
487
- encoder_selfattn_layer_args = (
488
- attention_heads,
489
- output_size,
490
- dropout_rate,
491
- )
492
-
493
- # feed-forward module definition
494
- positionwise_layer = PositionwiseFeedForward
495
- positionwise_layer_args = (
496
- output_size,
497
- linear_units,
498
- dropout_rate,
499
- activation,
500
- )
501
- # convolution module definition
502
- convolution_layer = ConvolutionModule
503
- convolution_layer_args = (output_size,
504
- cnn_module_kernel,
505
- activation,)
506
-
507
- self.encoders = torch.nn.ModuleList([
508
- ConformerEncoderLayer(
509
- output_size,
510
- encoder_selfattn_layer(*encoder_selfattn_layer_args),
511
- positionwise_layer(*positionwise_layer_args),
512
- positionwise_layer(
513
- *positionwise_layer_args) if macaron_style else None,
514
- convolution_layer(
515
- *convolution_layer_args) if use_cnn_module else None,
516
- dropout_rate,
517
- normalize_before,
518
- concat_after,
519
- ) for _ in range(num_blocks)
520
- ])
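Editor's note: a hedged usage sketch (not part of the commit) for the ConformerEncoder removed above, assuming the indextts.gpt package were still importable. It shows the forward contract documented in BaseEncoder.forward: padded features plus their true lengths go in, subsampled features and a (B, 1, T') padding mask come out. All argument values are illustrative only.

import torch
from indextts.gpt.conformer_encoder import ConformerEncoder  # removed by this commit

enc = ConformerEncoder(input_size=100, output_size=256, attention_heads=4,
                       linear_units=1024, num_blocks=2, input_layer="conv2d2",
                       pos_enc_layer_type="rel_pos", macaron_style=True,
                       use_cnn_module=True, cnn_module_kernel=15)
feats = torch.randn(2, 200, 100)        # (B, T, input_size), e.g. 100-band mel features
feat_lens = torch.tensor([200, 150])    # true lengths before padding
out, mask = enc(feats, feat_lens)
print(out.shape, mask.shape)            # roughly (2, 99, 256) and (2, 1, 99) after 1/2 subsampling
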
indextts/gpt/model.py DELETED
@@ -1,713 +0,0 @@
1
- import functools
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- import transformers
8
- from transformers import GPT2Config, LogitsProcessorList
9
- from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model
10
-
11
- # from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
12
- from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
13
- from transformers.utils.model_parallel_utils import (assert_device_map,
14
- get_device_map)
15
-
16
- from indextts.gpt.conformer_encoder import ConformerEncoder
17
- from indextts.gpt.perceiver import PerceiverResampler
18
- from indextts.utils.arch_util import AttentionBlock
19
- from indextts.utils.typical_sampling import TypicalLogitsWarper
20
-
21
-
22
- def null_position_embeddings(range, dim):
23
- return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
24
-
25
-
26
- class ResBlock(nn.Module):
27
- """
28
- Basic residual convolutional block that uses GroupNorm.
29
- """
30
-
31
- def __init__(self, chan):
32
- super().__init__()
33
- self.net = nn.Sequential(
34
- nn.Conv1d(chan, chan, kernel_size=3, padding=1),
35
- nn.GroupNorm(chan // 8, chan),
36
- nn.ReLU(),
37
- nn.Conv1d(chan, chan, kernel_size=3, padding=1),
38
- nn.GroupNorm(chan // 8, chan)
39
- )
40
-
41
- def forward(self, x):
42
- return F.relu(self.net(x) + x)
43
-
44
-
45
- class GPT2InferenceModel(GPT2PreTrainedModel):
46
- def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
47
- super().__init__(config)
48
- # Note: the argument named `text_pos_emb` here actually represents the mel position embedding
49
- self.transformer = gpt
50
- self.text_pos_embedding = text_pos_emb
51
- self.embeddings = embeddings
52
- self.final_norm = norm
53
- self.lm_head = nn.Sequential(norm, linear)
54
- self.kv_cache = kv_cache
55
-
56
- # Model parallel
57
- self.model_parallel = False
58
- self.device_map = None
59
- self.cached_mel_emb = None
60
-
61
- def parallelize(self, device_map=None):
62
- self.device_map = (
63
- get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
64
- if device_map is None
65
- else device_map
66
- )
67
- assert_device_map(self.device_map, len(self.transformer.h))
68
- self.transformer.parallelize(self.device_map)
69
- self.lm_head = self.lm_head.to(self.transformer.first_device)
70
- self.model_parallel = True
71
-
72
- def deparallelize(self):
73
- self.transformer.deparallelize()
74
- self.transformer = self.transformer.to("cpu")
75
- self.lm_head = self.lm_head.to("cpu")
76
- self.model_parallel = False
77
- torch.cuda.empty_cache()
78
- if torch.backends.mps.is_available():
79
- torch.mps.empty_cache()
80
-
81
- def get_output_embeddings(self):
82
- return self.lm_head
83
-
84
- def set_output_embeddings(self, new_embeddings):
85
- self.lm_head = new_embeddings
86
-
87
- def store_mel_emb(self, mel_emb):
88
- self.cached_mel_emb = mel_emb
89
-
90
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
91
- token_type_ids = kwargs.get("token_type_ids", None) # usually None
92
- if not self.kv_cache:
93
- past_key_values = None
94
- # only last token for inputs_ids if past is defined in kwargs
95
- if past_key_values:
96
- input_ids = input_ids[:, -1].unsqueeze(-1)
97
- if token_type_ids is not None:
98
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
99
-
100
- attention_mask = kwargs.get("attention_mask", None)
101
- position_ids = kwargs.get("position_ids", None)
102
-
103
- if attention_mask is not None and position_ids is None:
104
- # create position_ids on the fly for batch generation
105
- position_ids = attention_mask.long().cumsum(-1) - 1
106
- position_ids.masked_fill_(attention_mask == 0, 0)
107
- if past_key_values:
108
- position_ids = position_ids[:, -1].unsqueeze(-1)
109
- else:
110
- position_ids = None
111
- return {
112
- "input_ids": input_ids,
113
- "past_key_values": past_key_values,
114
- "use_cache": kwargs.get("use_cache"),
115
- "position_ids": position_ids,
116
- "attention_mask": attention_mask,
117
- "token_type_ids": token_type_ids,
118
- }
119
-
120
- def forward(
121
- self,
122
- input_ids=None,
123
- past_key_values=None,
124
- attention_mask=None,
125
- token_type_ids=None,
126
- position_ids=None,
127
- head_mask=None,
128
- inputs_embeds=None,
129
- encoder_hidden_states=None,
130
- encoder_attention_mask=None,
131
- labels=None,
132
- use_cache=None,
133
- output_attentions=None,
134
- output_hidden_states=None,
135
- return_dict=None,
136
- ):
137
- assert self.cached_mel_emb is not None
138
- assert inputs_embeds is None # Not supported by this inference model.
139
- assert labels is None # Training not supported by this inference model.
140
- return_dict = (
141
- return_dict if return_dict is not None else self.config.use_return_dict
142
- )
143
- # Create embedding
144
- mel_len = self.cached_mel_emb.shape[1]
145
- if input_ids.shape[1] != 1:
146
- text_inputs = input_ids[:, mel_len:]
147
- text_emb = self.embeddings(text_inputs)
148
- text_emb = text_emb + self.text_pos_embedding(text_emb)
149
- if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
150
- mel_emb = self.cached_mel_emb.repeat_interleave(
151
- text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
152
- )
153
- else: # this outcome only occurs once per loop in most cases
154
- mel_emb = self.cached_mel_emb
155
- emb = torch.cat([mel_emb, text_emb], dim=1)
156
- else:
157
- emb = self.embeddings(input_ids)
158
- emb = emb + self.text_pos_embedding.get_fixed_embedding(
159
- attention_mask.shape[1] - mel_len, attention_mask.device
160
- )
161
- transformer_outputs = self.transformer(
162
- inputs_embeds=emb,
163
- past_key_values=past_key_values,
164
- attention_mask=attention_mask,
165
- token_type_ids=token_type_ids,
166
- position_ids=position_ids,
167
- head_mask=head_mask,
168
- encoder_hidden_states=encoder_hidden_states,
169
- encoder_attention_mask=encoder_attention_mask,
170
- use_cache=use_cache,
171
- output_attentions=output_attentions,
172
- output_hidden_states=output_hidden_states,
173
- return_dict=return_dict,
174
- )
175
- hidden_states = transformer_outputs[0]
176
-
177
- # Set device for model parallelism
178
- if self.model_parallel:
179
- if torch.backends.mps.is_available():
180
- self.to(self.transformer.first_device)
181
- else:
182
- torch.cuda.set_device(self.transformer.first_device)
183
- hidden_states = hidden_states.to(self.lm_head.weight.device)
184
-
185
- lm_logits = self.lm_head(hidden_states)
186
-
187
- if not return_dict:
188
- return (lm_logits,) + transformer_outputs[1:]
189
-
190
- return CausalLMOutputWithCrossAttentions(
191
- loss=None,
192
- logits=lm_logits,
193
- past_key_values=transformer_outputs.past_key_values,
194
- hidden_states=transformer_outputs.hidden_states,
195
- attentions=transformer_outputs.attentions,
196
- cross_attentions=transformer_outputs.cross_attentions,
197
- )
198
-
199
- @staticmethod
200
- def _reorder_cache(past, beam_idx):
201
- """
202
- This function is used to re-order the :obj:`past_key_values` cache if
203
- :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
204
- called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
205
- """
206
- return tuple(
207
- tuple(
208
- past_state.index_select(0, beam_idx.to(past_state.device))
209
- for past_state in layer_past
210
- )
211
- for layer_past in past
212
- )
213
-
214
-
215
- class ConditioningEncoder(nn.Module):
216
- def __init__(self,
217
- spec_dim,
218
- embedding_dim,
219
- attn_blocks=6,
220
- num_attn_heads=4,
221
- do_checkpointing=False,
222
- mean=False):
223
- super().__init__()
224
- attn = []
225
- self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
226
- for a in range(attn_blocks):
227
- attn.append(AttentionBlock(embedding_dim, num_attn_heads))
228
- self.attn = nn.Sequential(*attn)
229
- self.dim = embedding_dim
230
- self.do_checkpointing = do_checkpointing
231
- self.mean = mean
232
-
233
- def forward(self, x):
234
- h = self.init(x)
235
- h = self.attn(h)
236
- if self.mean:
237
- return h.mean(dim=2)
238
- else:
239
- return h
240
- # return h[:, :, 0]
241
-
242
-
243
- class LearnedPositionEmbeddings(nn.Module):
244
- def __init__(self, seq_len, model_dim, init=.02):
245
- super().__init__()
246
- self.emb = nn.Embedding(seq_len, model_dim)
247
- # Initializing this way is standard for GPT-2
248
- self.emb.weight.data.normal_(mean=0.0, std=init)
249
-
250
- def forward(self, x):
251
- sl = x.shape[1]
252
- return self.emb(torch.arange(0, sl, device=x.device))
253
-
254
- def get_fixed_embedding(self, ind, dev):
255
- return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
256
-
257
-
258
- def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function):
259
- """
260
- GPT-2 implemented by the HuggingFace library.
261
- """
262
- from transformers import GPT2Config, GPT2Model
263
- gpt_config = GPT2Config(vocab_size=256, # Unused.
264
- n_positions=max_mel_seq_len + max_text_seq_len,
265
- n_ctx=max_mel_seq_len + max_text_seq_len,
266
- n_embd=model_dim,
267
- n_layer=layers,
268
- n_head=heads,
269
- activation_function=activation_function or "gelu_new",
270
- gradient_checkpointing=checkpointing,
271
- use_cache=not checkpointing)
272
- gpt = GPT2Model(gpt_config)
273
- # Override the built in positional embeddings
274
- del gpt.wpe
275
- gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
276
- # Built-in token embeddings are unused.
277
- del gpt.wte
278
- return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
279
- None, None
280
-
281
-
282
- class MelEncoder(nn.Module):
283
- def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
284
- super().__init__()
285
- self.channels = channels
286
- self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
287
- nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
288
- nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
289
- nn.GroupNorm(channels // 16, channels // 2),
290
- nn.ReLU(),
291
- nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
292
- nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
293
- nn.GroupNorm(channels // 8, channels),
294
- nn.ReLU(),
295
- nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
296
- )
297
- self.reduction = 4
298
-
299
- def forward(self, x):
300
- for e in self.encoder:
301
- x = e(x)
302
- return x.permute(0, 2, 1)
303
-
304
-
305
- class UnifiedVoice(nn.Module):
306
- def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
307
- mel_length_compression=1024, number_text_tokens=256,
308
- start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
309
- train_solo_embeddings=False, use_mel_codes_as_input=True,
310
- checkpointing=True, types=1, activation_function=None,
311
- condition_num_latent=32, condition_type="perceiver", condition_module=None):
312
- """
313
- Args:
314
- layers: Number of layers in transformer stack.
315
- model_dim: Operating dimensions of the transformer
316
- heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
317
- max_text_tokens: Maximum number of text tokens that will be encountered by model.
318
- max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
319
- max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
320
- mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
321
- number_text_tokens:
322
- start_text_token:
323
- stop_text_token:
324
- number_mel_codes:
325
- start_mel_token:
326
- stop_mel_token:
327
- train_solo_embeddings:
328
- use_mel_codes_as_input:
329
- checkpointing:
330
- condition_type: perceiver, gst or default encoder
331
- """
332
- super().__init__()
333
- self.number_text_tokens = number_text_tokens
334
- self.start_text_token = start_text_token
335
- self.stop_text_token = stop_text_token
336
- self.number_mel_codes = number_mel_codes
337
- self.start_mel_token = start_mel_token
338
- self.stop_mel_token = stop_mel_token
339
- self.layers = layers
340
- self.heads = heads
341
- self.max_mel_tokens = max_mel_tokens
342
- self.max_text_tokens = max_text_tokens
343
- self.model_dim = model_dim
344
- self.max_conditioning_inputs = max_conditioning_inputs
345
- self.mel_length_compression = mel_length_compression
346
- self.condition_type = condition_type
347
- self.cond_num = condition_num_latent
348
- self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
349
- if condition_type == "perceiver":
350
- self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads)
351
- self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num)
352
- elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder":
353
- self.conditioning_encoder = ConformerEncoder(input_size=100,
354
- output_size=condition_module['output_size'],
355
- linear_units=condition_module['linear_units'],
356
- attention_heads=condition_module['attention_heads'],
357
- num_blocks=condition_module['num_blocks'],
358
- input_layer=condition_module['input_layer'])
359
- if condition_type == "conformer_perceiver":
360
- self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
361
- ff_mult=condition_module['perceiver_mult'],
362
- heads=condition_module['attention_heads'],
363
- num_latents=self.cond_num)
364
- else:
365
- self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads, mean=True)
366
-
367
- self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
368
- if use_mel_codes_as_input:
369
- self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
370
- else:
371
- self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
372
- self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
373
- build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
374
- self.max_text_tokens + 2, checkpointing, activation_function)
375
- if train_solo_embeddings:
376
- self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
377
- self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
378
- else:
379
- self.mel_solo_embedding = 0
380
- self.text_solo_embedding = 0
381
-
382
- self.final_norm = nn.LayerNorm(model_dim)
383
- self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
384
- self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
385
-
386
- # Initialize the embeddings per the GPT-2 scheme
387
- embeddings = [self.text_embedding]
388
- if use_mel_codes_as_input:
389
- embeddings.append(self.mel_embedding)
390
- for module in embeddings:
391
- module.weight.data.normal_(mean=0.0, std=.02)
392
-
393
- def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
394
- seq_length = self.max_mel_tokens + self.max_text_tokens + 2
395
- gpt_config = GPT2Config(
396
- vocab_size=self.number_mel_codes,
397
- n_positions=seq_length,
398
- n_ctx=seq_length,
399
- n_embd=self.model_dim,
400
- n_layer=self.layers,
401
- n_head=self.heads,
402
- gradient_checkpointing=False,
403
- use_cache=True,
404
- )
405
- self.inference_model = GPT2InferenceModel(
406
- gpt_config,
407
- self.gpt,
408
- self.mel_pos_embedding,
409
- self.mel_embedding,
410
- self.final_norm,
411
- self.mel_head,
412
- kv_cache=kv_cache,
413
- )
414
- if use_deepspeed and half and torch.cuda.is_available():
415
- import deepspeed
416
- self.ds_engine = deepspeed.init_inference(model=self.inference_model,
417
- mp_size=1,
418
- replace_with_kernel_inject=False,
419
- dtype=torch.float16)
420
- self.inference_model = self.ds_engine.module.eval()
421
- elif use_deepspeed and torch.cuda.is_available():
422
- import deepspeed
423
- self.ds_engine = deepspeed.init_inference(model=self.inference_model,
424
- mp_size=1,
425
- replace_with_kernel_inject=False,
426
- dtype=torch.float32)
427
- self.inference_model = self.ds_engine.module.eval()
428
- else:
429
- self.inference_model = self.inference_model.eval()
430
-
431
- # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
432
- self.gpt.wte = self.mel_embedding
433
-
434
- def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
435
- inp = F.pad(input, (1, 0), value=start_token)
436
- tar = F.pad(input, (0, 1), value=stop_token)
437
- return inp, tar
438
-
439
- def set_mel_padding(self, mel_input_tokens, mel_lengths):
440
- """
441
- Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
442
- that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
443
- preformatting to create a working TTS model.
444
- """
445
- for b in range(len(mel_lengths)):
446
- # Due to the convolutional nature of how these tokens are generated,
447
- # it would be best if the model predicts a token past the actual last token.
448
- actual_end = mel_lengths[b]
449
- if actual_end < mel_input_tokens.shape[-1]:
450
- mel_input_tokens[b, actual_end:] = self.stop_mel_token
451
- return mel_input_tokens
452
-
453
- def set_text_padding(self, text_input_tokens, text_lengths):
454
- """
455
- Given text tokens that are derived from padded text inputs and the actual lengths of each batch element in
456
- that batch, reformats the tokens with STOP_TEXT_TOKEN in place of the zero padding. This is required
457
- preformatting to create a working TTS model.
458
- """
459
- for b in range(len(text_lengths)):
460
- # Due to the convolutional nature of how these tokens are generated,
461
- # it would be best if the model predicts a token past the actual last token.
462
- actual_end = text_lengths[b]
463
- if actual_end < text_input_tokens.shape[-1]:
464
- text_input_tokens[b, actual_end:] = self.stop_text_token
465
- return text_input_tokens
466
-
467
- def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
468
- if second_inputs is not None:
469
- emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
470
- else:
471
- emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
472
-
473
- gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
474
- if get_attns:
475
- return gpt_out.attentions
476
-
477
- offset = speech_conditioning_inputs.shape[1]
478
- enc = gpt_out.last_hidden_state[:, offset:]
479
- enc = self.final_norm(enc)
480
-
481
- if return_latent:
482
- return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
483
-
484
- first_logits = enc[:, :first_inputs.shape[1]]
485
- first_logits = first_head(first_logits)
486
- first_logits = first_logits.permute(0, 2, 1)
487
- if second_inputs is not None:
488
- second_logits = enc[:, -second_inputs.shape[1]:]
489
- second_logits = second_head(second_logits)
490
- second_logits = second_logits.permute(0, 2, 1)
491
- return first_logits, second_logits
492
- else:
493
- return first_logits
494
-
495
- def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
496
- if self.condition_type == "perceiver":
497
- if speech_conditioning_input.ndim == 4:
498
- speech_conditioning_input = speech_conditioning_input.squeeze(1)
499
- speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input) # (b, d, s)
500
- conds = self.perceiver_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 32, d)
501
- elif self.condition_type == "conformer_perceiver":
502
- speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
503
- cond_mel_lengths) # (b, s, d), (b, 1, s)
504
- if self.condition_type == "conformer_perceiver":
505
- # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
506
- conds_mask = self.cond_mask_pad(mask.squeeze(1))
507
- conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
508
- elif self.condition_type == "gst":
509
- if speech_conditioning_input.ndim == 4:
510
- speech_conditioning_input = speech_conditioning_input.squeeze(1)
511
- conds = self.gst_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 1, d)
512
- else:
513
- speech_conditioning_input = (
514
- speech_conditioning_input.unsqueeze(1)
515
- if len(speech_conditioning_input.shape) == 3
516
- else speech_conditioning_input
517
- )
518
- conds = []
519
- for j in range(speech_conditioning_input.shape[1]):
520
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
521
- conds = torch.stack(conds, dim=1)
522
- conds = conds.mean(dim=1)
523
- conds = conds.unsqueeze(1)
524
- return conds
525
-
526
- def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths,
527
- cond_mel_lengths=None, types=None, text_first=True, raw_mels=None, return_attentions=False,
528
- return_latent=False, clip_inputs=False):
529
- """
530
- Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
531
- (actuated by `text_first`).
532
-
533
- speech_conditioning_latent: conditioning MEL float tensor, (b, n_mels, frames)
534
- text_inputs: long tensor, (b,t)
535
- text_lengths: long tensor, (b,)
536
- mel_codes: long tensor, (b,m)
537
- wav_lengths: long tensor, (b,)
538
- raw_mels: MEL float tensor (b,80,s)
539
-
540
- If return_attentions is specified, only logits are returned.
541
- If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
542
- If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
543
- """
544
-
545
- speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
546
- # Types are expressed by expanding the text embedding space.
547
- if types is not None:
548
- text_inputs = text_inputs * (1 + types).unsqueeze(-1)
549
-
550
- if clip_inputs:
551
- # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
552
- # chopping the inputs by the maximum actual length.
553
- max_text_len = text_lengths.max()
554
- text_inputs = text_inputs[:, :max_text_len]
555
- max_mel_len = wav_lengths.max() // self.mel_length_compression
556
- mel_codes = mel_codes[:, :max_mel_len]
557
- if raw_mels is not None:
558
- raw_mels = raw_mels[:, :, :max_mel_len * 4]
559
-
560
- # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
561
- # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
562
- mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
563
- mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
564
- text_inputs = self.set_text_padding(text_inputs, text_lengths)
565
- text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
566
- mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
567
-
568
- conds = speech_conditioning_latent
569
- text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
570
- text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
571
- mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
572
- if raw_mels is not None:
573
- mel_inp = F.pad(raw_mels, (0, 8))
574
- else:
575
- mel_inp = mel_codes
576
- mel_emb = self.mel_embedding(mel_inp)
577
- mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
578
-
579
- if text_first:
580
- # print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
581
- text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
582
- if return_latent:
583
- return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
584
- else:
585
- mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
586
- if return_latent:
587
- return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
588
-
589
- if return_attentions:
590
- return mel_logits
591
-
592
- loss_text = F.cross_entropy(text_logits, text_targets.long())
593
- loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
594
- return loss_text.mean(), loss_mel.mean(), mel_logits
595
-
596
- def prepare_gpt_inputs(
597
- self,
598
- conditional_latents: torch.Tensor,
599
- text_inputs: torch.Tensor,
600
- ):
601
-
602
- """
603
- Prepare the inputs for the GPT2InferenceModel to generate.
604
- Args:
605
- conditional_latents: (b, 32, dim) audio conditioning embedding from `get_conditioning()`
606
- text_inputs: (b, L)
607
- Returns:
608
- input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate()
609
- inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward()
610
- attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate()
611
- """
612
- b, L = text_inputs.shape[:2]
613
- device = text_inputs.device
614
- single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
615
- if not single_cond:
616
- assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
617
- batched_mel_emb = []
618
- attention_masks = []
619
- target_len = conditional_latents.shape[1] + L + 2
620
- for i in range(b):
621
- valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token)
622
- text_input = text_inputs[i][valid_mask]
623
- text_input = F.pad(text_input, (1, 0), value=self.start_text_token)
624
- text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
625
- text_input_pos = torch.arange(0, text_input.size(-1), device=device)
626
- text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
627
- # concatenate [conditional latents][text embeddings]
628
- conds_text_emb = [
629
- conditional_latents.squeeze(0) if single_cond else conditional_latents[i],
630
- text_emb,
631
- ]
632
- # +1 for the start_mel_token
633
- attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device)
634
- # check this text input is padded
635
- padding: int = L + 2 - text_input.size(-1)
636
- # pad left of [cond][text] -> [pad][cond][text]
637
- if padding > 0:
638
- pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim]
639
- conds_text_emb.insert(0, pad)
640
- attention_mask[:padding] = 0
641
- mel_emb = torch.cat(conds_text_emb) #[s, dim]
642
- assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
643
- batched_mel_emb.append(mel_emb)
644
- attention_masks.append(attention_mask)
645
- # [b, s, dim]
646
- batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
647
- # [b, s+1]
648
- attention_mask = torch.stack(attention_masks, dim=0)
649
- # [b, s+1]
650
- fake_inputs = torch.ones(
651
- (
652
- batched_mel_emb.shape[0],
653
- batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token
654
- ),
655
- dtype=torch.long,
656
- device=device,
657
- )
658
- fake_inputs[:, -1] = self.start_mel_token
659
- return fake_inputs, batched_mel_emb, attention_mask
660
- def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
661
- max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
662
- """
663
- Args:
664
- speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames)
665
- text_inputs: (b, L)
666
- cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,)
667
- input_tokens: additional tokens for generation in shape (b, s) or (s,)
668
- max_generate_length: limit the number of generated tokens
669
- hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
670
- """
671
- if speech_conditioning_mel.ndim == 2:
672
- speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0)
673
- if cond_mel_lengths is None:
674
- cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device)
675
- conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths)
676
- input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
677
- self.inference_model.store_mel_emb(inputs_embeds)
678
- if input_tokens is None:
679
- inputs = input_ids
680
- else:
681
- if input_tokens.ndim == 1:
682
- input_tokens = input_tokens.unsqueeze(0)
683
- assert num_return_sequences % input_tokens.shape[0] == 0, \
684
- "The num_return_sequences must be divisible by the batch number of input_tokens"
685
- assert num_return_sequences % text_inputs.shape[0] == 0, \
686
- "The num_return_sequences must be divisible by the batch number of text_inputs"
687
- b = num_return_sequences // input_ids.shape[0]
688
- if b > 1:
689
- input_ids = input_ids.repeat(b, 1)
690
- attention_mask = attention_mask.repeat(b, 1)
691
- input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
692
- inputs = torch.cat([input_ids, input_tokens], dim=1)
693
- attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
694
- trunc_index = inputs.shape[1]
695
- logits_processor = LogitsProcessorList()
696
- if typical_sampling:
697
- # employ custom typical sampling
698
- if not (typical_mass > 0.0 and typical_mass < 1.0):
699
- raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
700
- min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
701
- logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
702
- max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
703
- output = self.inference_model.generate(inputs,
704
- bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
705
- eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
706
- max_length=max_length, logits_processor=logits_processor,
707
- num_return_sequences=num_return_sequences,
708
- **hf_generate_kwargs)
709
- if isinstance(output, torch.Tensor):
710
- return output[:, trunc_index:]
711
- # GenerateOutput
712
- output.sequences = output.sequences[:, trunc_index:]
713
- return output
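
The `prepare_gpt_inputs` method removed above left-pads each batch item so the conditioning latents, the start/stop-delimited text embeddings, and the trailing start-of-mel token all line up at a common length, with the attention mask zeroed over the padding. Below is a minimal, self-contained sketch of that layout only (dummy dimensions, token ids, and random embeddings; not the deleted code itself):

import torch
import torch.nn.functional as F

# made-up sizes: embedding dim, start/stop text token ids, and the longest text length L in the batch
dim, start_text, stop_text, L = 8, 0, 1, 6
cond = torch.randn(32, dim)                 # (32, dim) conditioning latents for one batch item
text = torch.tensor([5, 7, 9])              # unpadded text token ids for that item

text = F.pad(F.pad(text, (1, 0), value=start_text), (0, 1), value=stop_text)  # [start, ..., stop]
text_emb = torch.randn(text.size(0), dim)   # stand-in for text_embedding(...) + positional embedding

target_len = cond.size(0) + L + 2           # every item is brought to the same total length
padding = L + 2 - text.size(0)              # how much left padding this item needs
mel_emb = torch.cat([torch.zeros(padding, dim), cond, text_emb])   # [pad][cond][text]

attention_mask = torch.ones(target_len + 1, dtype=torch.long)      # +1 for the start_mel token
attention_mask[:padding] = 0                # padded positions are masked out

assert mel_emb.shape == (target_len, dim)
print(mel_emb.shape, attention_mask.tolist())

Left padding keeps the most recent positions (end of text and the start-of-mel token) aligned across the batch, which is what `GPT2InferenceModel.generate()` consumes.
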
indextts/gpt/model_v2.py DELETED
@@ -1,747 +0,0 @@
1
- import functools
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- import transformers
8
- from transformers import GPT2Config, LogitsProcessorList
9
- from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model
10
-
11
- # from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
12
- from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
13
- from transformers.utils.model_parallel_utils import (assert_device_map,
14
- get_device_map)
15
-
16
- from indextts.gpt.conformer_encoder import ConformerEncoder
17
- from indextts.gpt.perceiver import PerceiverResampler
18
- from indextts.utils.arch_util import AttentionBlock
19
- from indextts.utils.typical_sampling import TypicalLogitsWarper
20
-
21
-
22
- def null_position_embeddings(range, dim):
23
- return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
24
-
25
-
26
- class ResBlock(nn.Module):
27
- """
28
- Basic residual convolutional block that uses GroupNorm.
29
- """
30
-
31
- def __init__(self, chan):
32
- super().__init__()
33
- self.net = nn.Sequential(
34
- nn.Conv1d(chan, chan, kernel_size=3, padding=1),
35
- nn.GroupNorm(chan // 8, chan),
36
- nn.ReLU(),
37
- nn.Conv1d(chan, chan, kernel_size=3, padding=1),
38
- nn.GroupNorm(chan // 8, chan)
39
- )
40
-
41
- def forward(self, x):
42
- return F.relu(self.net(x) + x)
43
-
44
-
45
- class GPT2InferenceModel(GPT2PreTrainedModel):
46
- def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
47
- super().__init__(config)
48
- # Note: the argument named `text_pos_emb` here actually represents the mel position embedding
49
- self.transformer = gpt
50
- self.text_pos_embedding = text_pos_emb
51
- self.embeddings = embeddings
52
- self.final_norm = norm
53
- self.lm_head = nn.Sequential(norm, linear)
54
- self.kv_cache = kv_cache
55
-
56
- # Model parallel
57
- self.model_parallel = False
58
- self.device_map = None
59
- self.cached_mel_emb = None
60
-
61
- def parallelize(self, device_map=None):
62
- self.device_map = (
63
- get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
64
- if device_map is None
65
- else device_map
66
- )
67
- assert_device_map(self.device_map, len(self.transformer.h))
68
- self.transformer.parallelize(self.device_map)
69
- self.lm_head = self.lm_head.to(self.transformer.first_device)
70
- self.model_parallel = True
71
-
72
- def deparallelize(self):
73
- self.transformer.deparallelize()
74
- self.transformer = self.transformer.to("cpu")
75
- self.lm_head = self.lm_head.to("cpu")
76
- self.model_parallel = False
77
- torch.cuda.empty_cache()
78
- if torch.backends.mps.is_available():
79
- torch.mps.empty_cache()
80
-
81
- def get_output_embeddings(self):
82
- return self.lm_head
83
-
84
- def set_output_embeddings(self, new_embeddings):
85
- self.lm_head = new_embeddings
86
-
87
- def store_mel_emb(self, mel_emb):
88
- self.cached_mel_emb = mel_emb
89
-
90
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
91
- token_type_ids = kwargs.get("token_type_ids", None) # usually None
92
- if not self.kv_cache:
93
- past_key_values = None
94
- # only keep the last token of input_ids if past_key_values is provided
95
- if past_key_values:
96
- input_ids = input_ids[:, -1].unsqueeze(-1)
97
- if token_type_ids is not None:
98
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
99
-
100
- attention_mask = kwargs.get("attention_mask", None)
101
- position_ids = kwargs.get("position_ids", None)
102
-
103
- if attention_mask is not None and position_ids is None:
104
- # create position_ids on the fly for batch generation
105
- position_ids = attention_mask.long().cumsum(-1) - 1
106
- position_ids.masked_fill_(attention_mask == 0, 0)
107
- if past_key_values:
108
- position_ids = position_ids[:, -1].unsqueeze(-1)
109
- else:
110
- position_ids = None
111
- return {
112
- "input_ids": input_ids,
113
- "past_key_values": past_key_values,
114
- "use_cache": kwargs.get("use_cache"),
115
- "position_ids": position_ids,
116
- "attention_mask": attention_mask,
117
- "token_type_ids": token_type_ids,
118
- }
119
-
120
- def forward(
121
- self,
122
- input_ids=None,
123
- past_key_values=None,
124
- attention_mask=None,
125
- token_type_ids=None,
126
- position_ids=None,
127
- head_mask=None,
128
- inputs_embeds=None,
129
- encoder_hidden_states=None,
130
- encoder_attention_mask=None,
131
- labels=None,
132
- use_cache=None,
133
- output_attentions=None,
134
- output_hidden_states=None,
135
- return_dict=None,
136
- ):
137
- assert self.cached_mel_emb is not None
138
- assert inputs_embeds is None # Not supported by this inference model.
139
- assert labels is None # Training not supported by this inference model.
140
- return_dict = (
141
- return_dict if return_dict is not None else self.config.use_return_dict
142
- )
143
- # Create embedding
144
- mel_len = self.cached_mel_emb.shape[1]
145
- if input_ids.shape[1] != 1:
146
- text_inputs = input_ids[:, mel_len:]
147
- text_emb = self.embeddings(text_inputs)
148
- text_emb = text_emb + self.text_pos_embedding(text_emb)
149
- if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
150
- mel_emb = self.cached_mel_emb.repeat_interleave(
151
- text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
152
- )
153
- else: # this outcome only occurs once per loop in most cases
154
- mel_emb = self.cached_mel_emb
155
- emb = torch.cat([mel_emb, text_emb], dim=1)
156
- else:
157
- emb = self.embeddings(input_ids)
158
- emb = emb + self.text_pos_embedding.get_fixed_embedding(
159
- attention_mask.shape[1] - mel_len, attention_mask.device
160
- )
161
- transformer_outputs = self.transformer(
162
- inputs_embeds=emb,
163
- past_key_values=past_key_values,
164
- attention_mask=attention_mask,
165
- token_type_ids=token_type_ids,
166
- position_ids=position_ids,
167
- head_mask=head_mask,
168
- encoder_hidden_states=encoder_hidden_states,
169
- encoder_attention_mask=encoder_attention_mask,
170
- use_cache=use_cache,
171
- output_attentions=output_attentions,
172
- output_hidden_states=output_hidden_states,
173
- return_dict=return_dict,
174
- )
175
- hidden_states = transformer_outputs[0]
176
-
177
- # Set device for model parallelism
178
- if self.model_parallel:
179
- if torch.backends.mps.is_available():
180
- self.to(self.transformer.first_device)
181
- else:
182
- torch.cuda.set_device(self.transformer.first_device)
183
- hidden_states = hidden_states.to(self.lm_head.weight.device)
184
-
185
- lm_logits = self.lm_head(hidden_states)
186
-
187
- if not return_dict:
188
- return (lm_logits,) + transformer_outputs[1:]
189
-
190
- return CausalLMOutputWithCrossAttentions(
191
- loss=None,
192
- logits=lm_logits,
193
- past_key_values=transformer_outputs.past_key_values,
194
- hidden_states=transformer_outputs.hidden_states,
195
- attentions=transformer_outputs.attentions,
196
- cross_attentions=transformer_outputs.cross_attentions,
197
- )
198
-
199
- @staticmethod
200
- def _reorder_cache(past, beam_idx):
201
- """
202
- This function is used to re-order the :obj:`past_key_values` cache if
203
- :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
204
- called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
205
- """
206
- return tuple(
207
- tuple(
208
- past_state.index_select(0, beam_idx.to(past_state.device))
209
- for past_state in layer_past
210
- )
211
- for layer_past in past
212
- )
213
-
214
-
215
- class ConditioningEncoder(nn.Module):
216
- def __init__(self,
217
- spec_dim,
218
- embedding_dim,
219
- attn_blocks=6,
220
- num_attn_heads=4,
221
- do_checkpointing=False,
222
- mean=False):
223
- super().__init__()
224
- attn = []
225
- self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
226
- for a in range(attn_blocks):
227
- attn.append(AttentionBlock(embedding_dim, num_attn_heads))
228
- self.attn = nn.Sequential(*attn)
229
- self.dim = embedding_dim
230
- self.do_checkpointing = do_checkpointing
231
- self.mean = mean
232
-
233
- def forward(self, x):
234
- h = self.init(x)
235
- h = self.attn(h)
236
- if self.mean:
237
- return h.mean(dim=2)
238
- else:
239
- return h
240
- # return h[:, :, 0]
241
-
242
-
243
- class LearnedPositionEmbeddings(nn.Module):
244
- def __init__(self, seq_len, model_dim, init=.02):
245
- super().__init__()
246
- self.emb = nn.Embedding(seq_len, model_dim)
247
- # Initializing this way is standard for GPT-2
248
- self.emb.weight.data.normal_(mean=0.0, std=init)
249
-
250
- def forward(self, x):
251
- sl = x.shape[1]
252
- return self.emb(torch.arange(0, sl, device=x.device))
253
-
254
- def get_fixed_embedding(self, ind, dev):
255
- return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
256
-
257
-
258
- def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
259
- """
260
- GPT-2 implemented by the HuggingFace library.
261
- """
262
- from transformers import GPT2Config, GPT2Model
263
- gpt_config = GPT2Config(vocab_size=256, # Unused.
264
- n_positions=max_mel_seq_len + max_text_seq_len,
265
- n_ctx=max_mel_seq_len + max_text_seq_len,
266
- n_embd=model_dim,
267
- n_layer=layers,
268
- n_head=heads,
269
- gradient_checkpointing=checkpointing,
270
- use_cache=not checkpointing)
271
- gpt = GPT2Model(gpt_config)
272
- # Override the built in positional embeddings
273
- del gpt.wpe
274
- gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
275
- # Built-in token embeddings are unused.
276
- del gpt.wte
277
- return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
278
- None, None
279
-
280
-
281
- class MelEncoder(nn.Module):
282
- def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
283
- super().__init__()
284
- self.channels = channels
285
- self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
286
- nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
287
- nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
288
- nn.GroupNorm(channels // 16, channels // 2),
289
- nn.ReLU(),
290
- nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
291
- nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
292
- nn.GroupNorm(channels // 8, channels),
293
- nn.ReLU(),
294
- nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
295
- )
296
- self.reduction = 4
297
-
298
- def forward(self, x):
299
- for e in self.encoder:
300
- x = e(x)
301
- return x.permute(0, 2, 1)
302
-
303
-
304
- class UnifiedVoice(nn.Module):
305
- def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
306
- mel_length_compression=1024, number_text_tokens=256,
307
- start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
308
- train_solo_embeddings=False, use_mel_codes_as_input=True,
309
- checkpointing=True, types=1,
310
- condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None):
311
- """
312
- Args:
313
- layers: Number of layers in transformer stack.
314
- model_dim: Operating dimensions of the transformer
315
- heads: Number of transformer attention heads. model_dim must be divisible by this value. Recommend model_dim//64
316
- max_text_tokens: Maximum number of text tokens that will be encountered by model.
317
- max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
318
- max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
319
- mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
320
- number_text_tokens:
321
- start_text_token:
322
- stop_text_token:
323
- number_mel_codes:
324
- start_mel_token:
325
- stop_mel_token:
326
- train_solo_embeddings:
327
- use_mel_codes_as_input:
328
- checkpointing:
329
- condition_type: perceiver, gst or default encoder
330
- """
331
- super().__init__()
332
- self.number_text_tokens = number_text_tokens
333
- self.start_text_token = start_text_token
334
- self.stop_text_token = stop_text_token
335
- self.number_mel_codes = number_mel_codes
336
- self.start_mel_token = start_mel_token
337
- self.stop_mel_token = stop_mel_token
338
- self.layers = layers
339
- self.heads = heads
340
- self.max_mel_tokens = max_mel_tokens
341
- self.max_text_tokens = max_text_tokens
342
- self.model_dim = model_dim
343
- self.max_conditioning_inputs = max_conditioning_inputs
344
- self.mel_length_compression = mel_length_compression
345
- self.condition_type = condition_type
346
- self.cond_num = condition_num_latent
347
- self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
348
- self.emo_cond_mask_pad = nn.ConstantPad1d((1, 0), True)
349
- if condition_type == "perceiver":
350
- self.conditioning_encoder = ConditioningEncoder(1024, model_dim, num_attn_heads=heads)
351
- self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num)
352
- elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder":
353
- self.conditioning_encoder = ConformerEncoder(input_size=1024,
354
- output_size=condition_module['output_size'],
355
- linear_units=condition_module['linear_units'],
356
- attention_heads=condition_module['attention_heads'],
357
- num_blocks=condition_module['num_blocks'],
358
- input_layer=condition_module['input_layer'])
359
- if condition_type == "conformer_perceiver":
360
- self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
361
- ff_mult=condition_module['perceiver_mult'],
362
- heads=condition_module['attention_heads'],
363
- num_latents=self.cond_num)
364
- else:
365
- self.conditioning_encoder = ConditioningEncoder(1024, model_dim, num_attn_heads=heads, mean=True)
366
-
367
- self.emo_conditioning_encoder = ConformerEncoder(input_size=1024,
368
- output_size=emo_condition_module['output_size'],
369
- linear_units=emo_condition_module['linear_units'],
370
- attention_heads=emo_condition_module['attention_heads'],
371
- num_blocks=emo_condition_module['num_blocks'],
372
- input_layer=emo_condition_module['input_layer'])
373
- self.emo_perceiver_encoder = PerceiverResampler(1024, dim_context=emo_condition_module['output_size'],
374
- ff_mult=emo_condition_module['perceiver_mult'],
375
- heads=emo_condition_module['attention_heads'],
376
- num_latents=1)
377
-
378
-
379
-
380
- self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
381
- self.emo_layer = nn.Linear(model_dim, model_dim)
382
- self.emovec_layer = nn.Linear(1024, model_dim)
383
-
384
- if use_mel_codes_as_input:
385
- self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
386
- else:
387
- self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
388
- self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
389
- build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
390
- self.max_text_tokens + 2, checkpointing)
391
- if train_solo_embeddings:
392
- self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
393
- self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
394
- else:
395
- self.mel_solo_embedding = 0
396
- self.text_solo_embedding = 0
397
-
398
- self.final_norm = nn.LayerNorm(model_dim)
399
- self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
400
- self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
401
-
402
- self.speed_emb = nn.Embedding(2, model_dim)
403
- self.speed_emb.weight.data.normal_(mean=0.0, std=0.0)
404
-
405
- # Initialize the embeddings per the GPT-2 scheme
406
- embeddings = [self.text_embedding]
407
- if use_mel_codes_as_input:
408
- embeddings.append(self.mel_embedding)
409
- for module in embeddings:
410
- module.weight.data.normal_(mean=0.0, std=.02)
411
-
412
- def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
413
- seq_length = self.max_mel_tokens + self.max_text_tokens + 2
414
- gpt_config = GPT2Config(
415
- vocab_size=self.number_mel_codes,
416
- n_positions=seq_length,
417
- n_ctx=seq_length,
418
- n_embd=self.model_dim,
419
- n_layer=self.layers,
420
- n_head=self.heads,
421
- gradient_checkpointing=False,
422
- use_cache=True,
423
- )
424
- self.inference_model = GPT2InferenceModel(
425
- gpt_config,
426
- self.gpt,
427
- self.mel_pos_embedding,
428
- self.mel_embedding,
429
- self.final_norm,
430
- self.mel_head,
431
- kv_cache=kv_cache,
432
- )
433
- if use_deepspeed and half and torch.cuda.is_available():
434
- import deepspeed
435
- self.ds_engine = deepspeed.init_inference(model=self.inference_model,
436
- mp_size=1,
437
- replace_with_kernel_inject=True,
438
- dtype=torch.float16)
439
- self.inference_model = self.ds_engine.module.eval()
440
- elif use_deepspeed and torch.cuda.is_available():
441
- import deepspeed
442
- self.ds_engine = deepspeed.init_inference(model=self.inference_model,
443
- mp_size=1,
444
- replace_with_kernel_inject=True,
445
- dtype=torch.float32)
446
- self.inference_model = self.ds_engine.module.eval()
447
- else:
448
- self.inference_model = self.inference_model.eval()
449
-
450
- # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
451
- self.gpt.wte = self.mel_embedding
452
-
453
- def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
454
- inp = F.pad(input, (1, 0), value=start_token)
455
- tar = F.pad(input, (0, 1), value=stop_token)
456
- return inp, tar
457
-
458
- def set_mel_padding(self, mel_input_tokens, mel_lengths):
459
- """
460
- Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
461
- that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
462
- preformatting to create a working TTS model.
463
- """
464
- for b in range(len(mel_lengths)):
465
- # Due to the convolutional nature of how these tokens are generated,
466
- # it would be best if the model predicts a token past the actual last token.
467
- actual_end = mel_lengths[b]
468
- if actual_end < mel_input_tokens.shape[-1]:
469
- mel_input_tokens[b, actual_end:] = self.stop_mel_token
470
- return mel_input_tokens
471
-
472
- def set_text_padding(self, text_input_tokens, text_lengths):
473
- """
474
- Given text tokens that are derived from a padded text sequence and the actual lengths of each batch element
475
- in that sequence, reformats the tokens with stop_text_token in place of the zero padding. This is required
476
- preformatting to create a working TTS model.
477
- """
478
- for b in range(len(text_lengths)):
479
- # Due to the convolutional nature of how these tokens are generated,
480
- # it would be best if the model predicts a token past the actual last token.
481
- actual_end = text_lengths[b]
482
- if actual_end < text_input_tokens.shape[-1]:
483
- text_input_tokens[b, actual_end:] = self.stop_text_token
484
- return text_input_tokens
485
-
486
- def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
487
- if second_inputs is not None:
488
- emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
489
- else:
490
- emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
491
-
492
- gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
493
- if get_attns:
494
- return gpt_out.attentions
495
-
496
- offset = speech_conditioning_inputs.shape[1]
497
- enc = gpt_out.last_hidden_state[:, offset:]
498
- enc = self.final_norm(enc)
499
-
500
- if return_latent:
501
- return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
502
-
503
- first_logits = enc[:, :first_inputs.shape[1]]
504
- first_logits = first_head(first_logits)
505
- first_logits = first_logits.permute(0, 2, 1)
506
- if second_inputs is not None:
507
- second_logits = enc[:, -second_inputs.shape[1]:]
508
- second_logits = second_head(second_logits)
509
- second_logits = second_logits.permute(0, 2, 1)
510
- return first_logits, second_logits
511
- else:
512
- return first_logits
513
-
514
- def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
515
- if self.condition_type == "perceiver":
516
- if speech_conditioning_input.ndim == 4:
517
- speech_conditioning_input = speech_conditioning_input.squeeze(1)
518
- speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input) # (b, d, s)
519
- conds = self.perceiver_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 32, d)
520
- elif self.condition_type == "conformer_perceiver":
521
- speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
522
- cond_mel_lengths) # (b, s, d), (b, 1, s)
523
- if self.condition_type == "conformer_perceiver":
524
- # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
525
- conds_mask = self.cond_mask_pad(mask.squeeze(1))
526
- conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
527
- elif self.condition_type == "gst":
528
- if speech_conditioning_input.ndim == 4:
529
- speech_conditioning_input = speech_conditioning_input.squeeze(1)
530
- conds = self.gst_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 1, d)
531
- else:
532
- speech_conditioning_input = (
533
- speech_conditioning_input.unsqueeze(1)
534
- if len(speech_conditioning_input.shape) == 3
535
- else speech_conditioning_input
536
- )
537
- conds = []
538
- for j in range(speech_conditioning_input.shape[1]):
539
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
540
- conds = torch.stack(conds, dim=1)
541
- conds = conds.mean(dim=1)
542
- conds = conds.unsqueeze(1)
543
- return conds
544
-
545
-
546
- def get_emo_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
547
- speech_conditioning_input, mask = self.emo_conditioning_encoder(speech_conditioning_input.transpose(1, 2),
548
- cond_mel_lengths) # (b, s, d), (b, 1, s)
549
- conds_mask = self.emo_cond_mask_pad(mask.squeeze(1))
550
- conds = self.emo_perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 1, d)
551
- return conds.squeeze(1)
552
-
553
-
554
- def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, mel_codes_lengths, emo_speech_conditioning_latent,
555
- cond_mel_lengths=None, emo_cond_mel_lengths=None, emo_vec=None, use_speed=None, do_spk_cond=False):
556
- """
557
- Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
558
-
559
- speech_conditioning_latent: float tensor of 1024-dim conditioning features (or a precomputed speaker latent when do_spk_cond is False)
560
- text_inputs: long tensor, (b,t)
561
- text_lengths: long tensor, (b,)
562
- mel_codes: long tensor, (b,m)
563
- mel_codes_lengths: long tensor, (b,)
564
-
565
- If return_attentions is specified, only logits are returned.
566
- If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
567
- """
568
-
569
- if do_spk_cond:
570
- speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent.transpose(1,2), cond_mel_lengths)
571
- else:
572
- speech_conditioning_latent = speech_conditioning_latent
573
-
574
- if emo_vec is None:
575
- emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_mel_lengths)
576
- emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
577
- emo_vec = self.emo_layer(emo_vec_syn)
578
-
579
- text_inputs = self.set_text_padding(text_inputs, text_lengths)
580
- text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
581
-
582
- mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
583
- mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
584
-
585
- duration_emb = self.speed_emb(torch.zeros_like(use_speed))
586
- duration_emb_half = self.speed_emb(torch.ones_like(use_speed))
587
- conds = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
588
- text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
589
- text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
590
- mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
591
-
592
- mel_emb = self.mel_embedding(mel_codes)
593
- mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
594
-
595
- text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=False, return_latent=True)
596
- return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
597
-
598
- def prepare_gpt_inputs(
599
- self,
600
- conditional_latents: torch.Tensor,
601
- text_inputs: torch.Tensor,
602
- ):
603
-
604
- """
605
- Prepare the inputs for the GPT2InferenceModel to generate.
606
- Args:
607
- conditional_latents: (b, n_cond, dim) audio conditioning embeddings (e.g. built from `get_conditioning()`)
608
- text_inputs: (b, L)
609
- Returns:
610
- input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate()
611
- inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward()
612
- attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate()
613
- """
614
- b, L = text_inputs.shape[:2]
615
- device = text_inputs.device
616
- single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
617
- if not single_cond:
618
- assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
619
- batched_mel_emb = []
620
- attention_masks = []
621
- target_len = conditional_latents.shape[1] + L + 2
622
- for i in range(b):
623
- valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token)
624
- text_input = text_inputs[i][valid_mask]
625
- text_input = F.pad(text_input, (1, 0), value=self.start_text_token)
626
- text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
627
- text_input_pos = torch.arange(0, text_input.size(-1), device=device)
628
- text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
629
- # concatenate [conditional latents][text embeddings]
630
- conds_text_emb = [
631
- conditional_latents.squeeze(0) if single_cond else conditional_latents[i],
632
- text_emb,
633
- ]
634
- # +1 for the start_mel_token
635
- attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device)
636
- # check whether this text input needs left padding
637
- padding: int = L + 2 - text_input.size(-1)
638
- # pad left of [cond][text] -> [pad][cond][text]
639
- if padding > 0:
640
- pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim]
641
- conds_text_emb.insert(0, pad)
642
- attention_mask[:padding] = 0
643
- mel_emb = torch.cat(conds_text_emb) #[s, dim]
644
- assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
645
- batched_mel_emb.append(mel_emb)
646
- attention_masks.append(attention_mask)
647
- # [b, s, dim]
648
- batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
649
- # [b, s+1]
650
- attention_mask = torch.stack(attention_masks, dim=0)
651
- # [b, s+1]
652
- fake_inputs = torch.ones(
653
- (
654
- batched_mel_emb.shape[0],
655
- batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token
656
- ),
657
- dtype=torch.long,
658
- device=device,
659
- )
660
- fake_inputs[:, -1] = self.start_mel_token
661
- return fake_inputs, batched_mel_emb, attention_mask
662
-
663
- def inference_speech(self, speech_condition, text_inputs, emo_speech_condition=None, cond_lengths=None, emo_cond_lengths=None, emo_vec=None, use_speed=False, input_tokens=None, num_return_sequences=1,
664
- max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
665
- """
666
- Args:
667
- speech_condition: (b, d, frames) or (d, frames)
668
- text_inputs: (b, L)
669
- cond_lengths: lengths of the conditioning features in shape (b,) or (1,)
670
- input_tokens: additional tokens for generation in shape (b, s) or (s,)
671
- max_generate_length: limit the number of generated tokens
672
- hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
673
- """
674
-
675
- if speech_condition.ndim == 2:
676
- speech_condition = speech_condition.unsqueeze(0)
677
- if emo_speech_condition is None:
678
- emo_speech_condition = speech_condition
679
- if cond_lengths is None:
680
- cond_lengths = torch.tensor([speech_condition.shape[-1]], device=speech_condition.device)
681
- if emo_cond_lengths is None:
682
- emo_cond_lengths = torch.tensor([emo_speech_condition.shape[-1]], device=speech_condition.device)
683
-
684
- speech_conditioning_latent = self.get_conditioning(speech_condition.transpose(1,2), cond_lengths)
685
- if emo_vec is None:
686
- print('compute emo vec')
687
- emo_vec = self.get_emo_conditioning(emo_speech_condition.transpose(1,2), emo_cond_lengths)
688
- emo_vec = self.emovec_layer(emo_vec)
689
- emo_vec = self.emo_layer(emo_vec)
690
- else:
691
- print('Use the specified emotion vector')
692
-
693
- tmp = torch.zeros(text_inputs.size(0)).to(text_inputs.device)
694
- duration_emb = self.speed_emb(torch.zeros_like(tmp).long())
695
- duration_emb_half = self.speed_emb(torch.ones_like(tmp).long())
696
- conds_latent = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
697
- input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
698
- self.inference_model.store_mel_emb(inputs_embeds)
699
- if input_tokens is None:
700
- inputs = input_ids
701
- else:
702
- if input_tokens.ndim == 1:
703
- input_tokens = input_tokens.unsqueeze(0)
704
- assert num_return_sequences % input_tokens.shape[0] == 0, \
705
- "The num_return_sequences must be divisible by the batch number of input_tokens"
706
- assert num_return_sequences % text_inputs.shape[0] == 0, \
707
- "The num_return_sequences must be divisible by the batch number of text_inputs"
708
- b = num_return_sequences // input_ids.shape[0]
709
- if b > 1:
710
- input_ids = input_ids.repeat(b, 1)
711
- attention_mask = attention_mask.repeat(b, 1)
712
- input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
713
- inputs = torch.cat([input_ids, input_tokens], dim=1)
714
- attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
715
- trunc_index = inputs.shape[1]
716
- logits_processor = LogitsProcessorList()
717
- if typical_sampling:
718
- # employ custom typical sampling
719
- if not (typical_mass > 0.0 and typical_mass < 1.0):
720
- raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
721
- min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
722
- logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
723
- max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
724
- output = self.inference_model.generate(inputs,
725
- bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
726
- eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
727
- max_length=max_length, logits_processor=logits_processor,
728
- num_return_sequences=num_return_sequences,
729
- **hf_generate_kwargs)
730
- if isinstance(output, torch.Tensor):
731
- return output[:, trunc_index:], speech_conditioning_latent
732
- # GenerateOutput
733
- output.sequences = output.sequences[:, trunc_index:]
734
- return output, speech_conditioning_latent
735
-
736
- def get_emovec(self, emo_speech_conditioning_latent, emo_cond_lengths):
737
- emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_lengths)
738
- emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
739
- emo_vec = self.emo_layer(emo_vec_syn)
740
- return emo_vec
741
-
742
- def merge_emovec(self, speech_conditioning_latent, emo_speech_conditioning_latent, cond_lengths, emo_cond_lengths, alpha = 1.0):
743
- emo_vec = self.get_emovec(emo_speech_conditioning_latent, emo_cond_lengths)
744
- base_vec = self.get_emovec(speech_conditioning_latent, cond_lengths)
745
-
746
- out = base_vec + alpha * (emo_vec - base_vec)
747
- return out
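
The emotion controls at the end of the removed `model_v2.py` (`merge_emovec`) come down to a linear interpolation between the speaker-derived and reference-derived emotion vectors. A tiny standalone illustration of that blend (tensor sizes and `alpha` are made up):

import torch

base_vec = torch.randn(1, 512)   # emotion vector derived from the speaker prompt (dummy)
emo_vec = torch.randn(1, 512)    # emotion vector derived from the emotion reference (dummy)
alpha = 0.7                      # 0.0 keeps the speaker emotion, 1.0 uses the reference fully

out = base_vec + alpha * (emo_vec - base_vec)          # same form as the removed merge_emovec()
assert torch.allclose(out, (1 - alpha) * base_vec + alpha * emo_vec)
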
indextts/gpt/perceiver.py DELETED
@@ -1,317 +0,0 @@
1
- # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
2
-
3
- from collections import namedtuple
4
- from functools import wraps
5
-
6
- import torch
7
- import torch.nn.functional as F
8
- from einops import rearrange, repeat
9
- from einops.layers.torch import Rearrange
10
- from packaging import version
11
- from torch import einsum, nn
12
-
13
-
14
- def exists(val):
15
- return val is not None
16
-
17
-
18
- def once(fn):
19
- called = False
20
-
21
- @wraps(fn)
22
- def inner(x):
23
- nonlocal called
24
- if called:
25
- return
26
- called = True
27
- return fn(x)
28
-
29
- return inner
30
-
31
-
32
- print_once = once(print)
33
-
34
-
35
- # main class
36
- class Attend(nn.Module):
37
- def __init__(self, dropout=0.0, causal=False, use_flash=False):
38
- super().__init__()
39
- self.dropout = dropout
40
- self.attn_dropout = nn.Dropout(dropout)
41
-
42
- self.causal = causal
43
- self.register_buffer("mask", None, persistent=False)
44
-
45
- self.use_flash = use_flash
46
- assert not (
47
- use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
48
- ), "in order to use flash attention, you must be using pytorch 2.0 or above"
49
-
50
- # determine efficient attention configs for cuda and cpu
51
- self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
52
- self.cpu_config = self.config(True, True, True)
53
- self.cuda_config = None
54
-
55
- if not torch.cuda.is_available() or not use_flash:
56
- return
57
-
58
- device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
59
-
60
- if device_properties.major == 8 and device_properties.minor == 0:
61
- print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
62
- self.cuda_config = self.config(True, False, False)
63
- else:
64
- print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
65
- self.cuda_config = self.config(False, True, True)
66
-
67
- def get_mask(self, n, device):
68
- if exists(self.mask) and self.mask.shape[-1] >= n:
69
- return self.mask[:n, :n]
70
-
71
- mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
72
- self.register_buffer("mask", mask, persistent=False)
73
- return mask
74
-
75
- def flash_attn(self, q, k, v, mask=None):
76
- _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
77
-
78
- # Recommended for multi-query single-key-value attention by Tri Dao
79
- # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
80
-
81
- if k.ndim == 3:
82
- k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
83
-
84
- if v.ndim == 3:
85
- v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
86
-
87
- # Check if mask exists and expand to compatible shape
88
- # The mask is B L, so it would have to be expanded to B H N L
89
-
90
- if exists(mask):
91
- mask = rearrange(mask, "b j -> b 1 1 j")
92
- mask = mask.expand(-1, heads, q_len, -1)
93
-
94
- # Check if there is a compatible device for flash attention
95
-
96
- config = self.cuda_config if is_cuda else self.cpu_config
97
-
98
- # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
99
-
100
- with torch.backends.cuda.sdp_kernel(**config._asdict()):
101
- out = F.scaled_dot_product_attention(
102
- q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal
103
- )
104
-
105
- return out
106
-
107
- def forward(self, q, k, v, mask=None):
108
- """
109
- einstein notation
110
- b - batch
111
- h - heads
112
- n, i, j - sequence length (base sequence length, source, target)
113
- d - feature dimension
114
- """
115
-
116
- n, device = q.shape[-2], q.device
117
-
118
- scale = q.shape[-1] ** -0.5
119
-
120
- if self.use_flash:
121
- return self.flash_attn(q, k, v, mask=mask)
122
-
123
- kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
124
-
125
- # similarity
126
-
127
- sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
128
-
129
- # key padding mask
130
-
131
- if exists(mask):
132
- mask = rearrange(mask, "b j -> b 1 1 j")
133
- sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
134
-
135
- # causal mask
136
-
137
- if self.causal:
138
- causal_mask = self.get_mask(n, device)
139
- sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
140
-
141
- # attention
142
-
143
- attn = sim.softmax(dim=-1)
144
- attn = self.attn_dropout(attn)
145
-
146
- # aggregate values
147
-
148
- out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
149
-
150
- return out
151
-
152
-
153
- def Sequential(*mods):
154
- return nn.Sequential(*filter(exists, mods))
155
-
156
-
157
- def exists(x):
158
- return x is not None
159
-
160
-
161
- def default(val, d):
162
- if exists(val):
163
- return val
164
- return d() if callable(d) else d
165
-
166
-
167
- class RMSNorm(nn.Module):
168
- def __init__(self, dim, scale=True, dim_cond=None):
169
- super().__init__()
170
- self.cond = exists(dim_cond)
171
- self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
172
-
173
- self.scale = dim**0.5
174
- self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
175
-
176
- def forward(self, x, cond=None):
177
- gamma = default(self.gamma, 1)
178
- out = F.normalize(x, dim=-1) * self.scale * gamma
179
-
180
- if not self.cond:
181
- return out
182
-
183
- assert exists(cond)
184
- gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
185
- gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
186
- return out * gamma + beta
187
-
188
-
189
- class CausalConv1d(nn.Conv1d):
190
- def __init__(self, *args, **kwargs):
191
- super().__init__(*args, **kwargs)
192
- (kernel_size,) = self.kernel_size
193
- (dilation,) = self.dilation
194
- (stride,) = self.stride
195
-
196
- assert stride == 1
197
- self.causal_padding = dilation * (kernel_size - 1)
198
-
199
- def forward(self, x):
200
- causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
201
- return super().forward(causal_padded_x)
202
-
203
-
204
- class GEGLU(nn.Module):
205
- def forward(self, x):
206
- x, gate = x.chunk(2, dim=-1)
207
- return F.gelu(gate) * x
208
-
209
-
210
- def FeedForward(dim, mult=4, causal_conv=False):
211
- dim_inner = int(dim * mult * 2 / 3)
212
-
213
- conv = None
214
- if causal_conv:
215
- conv = nn.Sequential(
216
- Rearrange("b n d -> b d n"),
217
- CausalConv1d(dim_inner, dim_inner, 3),
218
- Rearrange("b d n -> b n d"),
219
- )
220
-
221
- return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))
222
-
223
-
224
- class PerceiverResampler(nn.Module):
225
- def __init__(
226
- self,
227
- dim,
228
- depth=2,
229
- dim_context=None,
230
- num_latents=32,
231
- dim_head=64,
232
- heads=8,
233
- ff_mult=4,
234
- use_flash_attn=False,
235
- ):
236
- super().__init__()
237
- dim_context = default(dim_context, dim)
238
-
239
- self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
240
-
241
- self.latents = nn.Parameter(torch.randn(num_latents, dim))
242
- nn.init.normal_(self.latents, std=0.02)
243
-
244
- self.layers = nn.ModuleList([])
245
- for _ in range(depth):
246
- self.layers.append(
247
- nn.ModuleList(
248
- [
249
- Attention(
250
- dim=dim,
251
- dim_head=dim_head,
252
- heads=heads,
253
- use_flash=use_flash_attn,
254
- cross_attn_include_queries=True,
255
- ),
256
- FeedForward(dim=dim, mult=ff_mult),
257
- ]
258
- )
259
- )
260
-
261
- self.norm = RMSNorm(dim)
262
-
263
- def forward(self, x, mask=None):
264
- batch = x.shape[0]
265
-
266
- x = self.proj_context(x)
267
-
268
- latents = repeat(self.latents, "n d -> b n d", b=batch)
269
-
270
- for attn, ff in self.layers:
271
- latents = attn(latents, x, mask=mask) + latents
272
- latents = ff(latents) + latents
273
-
274
- return self.norm(latents)
275
-
276
-
277
- class Attention(nn.Module):
278
- def __init__(
279
- self,
280
- dim,
281
- *,
282
- dim_context=None,
283
- causal=False,
284
- dim_head=64,
285
- heads=8,
286
- dropout=0.0,
287
- use_flash=False,
288
- cross_attn_include_queries=False,
289
- ):
290
- super().__init__()
291
- self.scale = dim_head**-0.5
292
- self.heads = heads
293
- self.cross_attn_include_queries = cross_attn_include_queries
294
-
295
- dim_inner = dim_head * heads
296
- dim_context = default(dim_context, dim)
297
-
298
- self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
299
- self.to_q = nn.Linear(dim, dim_inner, bias=False)
300
- self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
301
- self.to_out = nn.Linear(dim_inner, dim, bias=False)
302
-
303
- def forward(self, x, context=None, mask=None):
304
- h, has_context = self.heads, exists(context)
305
-
306
- context = default(context, x)
307
-
308
- if has_context and self.cross_attn_include_queries:
309
- context = torch.cat((x, context), dim=-2)
310
-
311
- q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
312
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
313
-
314
- out = self.attend(q, k, v, mask=mask)
315
-
316
- out = rearrange(out, "b h n d -> b n (h d)")
317
- return self.to_out(out)
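
The removed `PerceiverResampler` compresses a variable-length conditioning sequence into a fixed number of latent vectors via cross-attention from learned queries. A rough, self-contained illustration of that idea using a stock `torch.nn.MultiheadAttention` layer (not the deleted class; all sizes are arbitrary):

import torch
import torch.nn as nn

dim, num_latents, heads = 512, 32, 8                           # arbitrary sizes
latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)   # learned query vectors
attn = nn.MultiheadAttention(dim, heads, batch_first=True)

context = torch.randn(2, 173, dim)                      # (batch, frames, dim), any frame count
queries = latents.unsqueeze(0).expand(2, -1, -1)

out, _ = attn(queries, context, context)                # queries cross-attend to the context
print(out.shape)                                        # torch.Size([2, 32, 512]) regardless of frames
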
indextts/gpt/transformers_beam_search.py DELETED
@@ -1,1013 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2020 The HuggingFace Inc. team
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from abc import ABC, abstractmethod
17
- from collections import UserDict
18
- from typing import Dict, List, Optional, Tuple, Union
19
-
20
- import numpy as np
21
- import torch
22
-
23
- from transformers.utils import add_start_docstrings
24
- from transformers.generation.beam_constraints import Constraint, ConstraintListState
25
-
26
-
27
- PROCESS_INPUTS_DOCSTRING = r"""
28
- Args:
29
- input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
30
- Indices of input sequence tokens in the vocabulary.
31
-
32
- Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
33
- [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
34
-
35
- [What are input IDs?](../glossary#input-ids)
36
- next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
37
- Current scores of the top `2 * num_beams` non-finished beam hypotheses.
38
- next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
39
- `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
40
- next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
41
- Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
42
- pad_token_id (`int`, *optional*):
43
- The id of the *padding* token.
44
- eos_token_id (`Union[int, List[int]]`, *optional*):
45
- The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
46
- beam_indices (`torch.LongTensor`, *optional*):
47
- Beam indices indicating to which beam hypothesis each token correspond.
48
- group_index (`int`, *optional*):
49
- The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
50
-
51
- Return:
52
- `UserDict`: A dictionary composed of the fields as defined above:
53
-
54
- - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
55
- non-finished beams.
56
- - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
57
- to the non-finished beam_hypotheses.
58
- - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
59
- indicating to which beam the next tokens shall be added.
60
-
61
- """
62
-
63
- FINALIZE_INPUTS_DOCSTRING = r"""
64
- Args:
65
- input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
66
- Indices of input sequence tokens in the vocabulary.
67
-
68
- Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
69
- [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
70
-
71
- [What are input IDs?](../glossary#input-ids)
72
- final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
73
- The final scores of all non-finished beams.
74
- final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
75
- The last tokens to be added to the non-finished beam_hypotheses.
76
- final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
77
- The beam indices indicating to which beam the `final_beam_tokens` shall be added.
78
- pad_token_id (`int`, *optional*):
79
- The id of the *padding* token.
80
- eos_token_id (`Union[int, List[int]]`, *optional*):
81
- The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
82
-
83
- Return:
84
- `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
85
- The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
86
- due to the `eos_token_id`.
87
-
88
- """
89
-
90
-
91
- class BeamScorer(ABC):
92
- """
93
- Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
94
- [`~PreTrainedModel.beam_sample`].
95
- """
96
-
97
- @abstractmethod
98
- @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
99
- def process(
100
- self,
101
- input_ids: torch.LongTensor,
102
- next_scores: torch.FloatTensor,
103
- next_tokens: torch.LongTensor,
104
- next_indices: torch.LongTensor,
105
- **kwargs,
106
- ) -> Tuple[torch.Tensor]:
107
- raise NotImplementedError("This is an abstract method.")
108
-
109
- @abstractmethod
110
- @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
111
- def finalize(
112
- self,
113
- input_ids: torch.LongTensor,
114
- next_scores: torch.FloatTensor,
115
- next_tokens: torch.LongTensor,
116
- next_indices: torch.LongTensor,
117
- max_length: int,
118
- **kwargs,
119
- ) -> torch.LongTensor:
120
- raise NotImplementedError("This is an abstract method.")
121
-
122
-
123
- class BeamSearchScorer(BeamScorer):
124
- r"""
125
- [`BeamScorer`] implementing standard beam search decoding.
126
-
127
- Adapted in part from [Facebook's XLM beam search
128
- code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
129
-
130
- Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
131
- implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
132
-
133
- Args:
134
- batch_size (`int`):
135
- Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
136
- num_beams (`int`):
137
- Number of beams for beam search.
138
- device (`torch.device`):
139
- Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
140
- allocated.
141
- length_penalty (`float`, *optional*, defaults to 1.0):
142
- Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
143
- the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
144
- likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
145
- `length_penalty` < 0.0 encourages shorter sequences.
146
- do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
147
- Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
148
- `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
149
- heuristic is applied and the generation stops when it is very unlikely to find better candidates;
150
- `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
151
- beam search algorithm).
152
- num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
153
- The number of beam hypotheses that shall be returned upon calling
154
- [`~transformers.BeamSearchScorer.finalize`].
155
- num_beam_groups (`int`, *optional*, defaults to 1):
156
- Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
157
- See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
158
- max_length (`int`, *optional*):
159
- The maximum length of the sequence to be generated.
160
- """
161
-
162
- def __init__(
163
- self,
164
- batch_size: int,
165
- num_beams: int,
166
- device: torch.device,
167
- length_penalty: Optional[float] = 1.0,
168
- do_early_stopping: Optional[Union[bool, str]] = False,
169
- num_beam_hyps_to_keep: Optional[int] = 1,
170
- num_beam_groups: Optional[int] = 1,
171
- max_length: Optional[int] = None,
172
- ):
173
- self.num_beams = num_beams
174
- self.device = device
175
- self.length_penalty = length_penalty
176
- self.do_early_stopping = do_early_stopping
177
- self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
178
- self.num_beam_groups = num_beam_groups
179
- self.group_size = self.num_beams // self.num_beam_groups
180
-
181
- self._is_init = False
182
- # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
183
- # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
184
- self._beam_hyps = [
185
- BeamHypotheses(
186
- num_beams=self.group_size,
187
- length_penalty=self.length_penalty,
188
- early_stopping=self.do_early_stopping,
189
- max_length=max_length,
190
- )
191
- for _ in range(batch_size * self.num_beam_groups)
192
- ]
193
- # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
194
- # in the i-th mini-batch is complete.
195
- self._done = torch.tensor(
196
- [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
197
- )
198
-
199
- if not isinstance(num_beams, int) or num_beams <= 1:
200
- raise ValueError(
201
- f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
202
- " one should make use of `greedy_search` instead."
203
- )
204
-
205
- if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
206
- raise ValueError(
207
- "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
208
- f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
209
- )
210
-
211
- @property
212
- def is_done(self) -> bool:
213
- return self._done.all()
214
-
215
- def process(
216
- self,
217
- input_ids: torch.LongTensor,
218
- next_scores: torch.FloatTensor,
219
- next_tokens: torch.LongTensor,
220
- next_indices: torch.LongTensor,
221
- pad_token_id: Optional[Union[int, torch.Tensor]] = None,
222
- eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
223
- beam_indices: Optional[torch.LongTensor] = None,
224
- group_index: Optional[int] = 0,
225
- decoder_prompt_len: Optional[int] = 0,
226
- ) -> Dict[str, torch.Tensor]:
227
- # add up to the length which the next_scores is calculated on (including decoder prompt)
228
- cur_len = input_ids.shape[-1] + 1
229
- batch_size = len(self._beam_hyps) // self.num_beam_groups
230
-
231
- if not (batch_size == (input_ids.shape[0] // self.group_size)):
232
- if self.num_beam_groups > 1:
233
- raise ValueError(
234
- f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
235
- f"size of {self.group_size} is expected by the beam scorer."
236
- )
237
- else:
238
- raise ValueError(
239
- f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
240
- f"{self.group_size} is expected by the beam scorer."
241
- )
242
-
243
- device = input_ids.device
244
- next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
245
- next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
246
- next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
247
-
248
- if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
249
- if isinstance(eos_token_id, int):
250
- eos_token_id = [eos_token_id]
251
- eos_token_id = torch.tensor(eos_token_id)
252
-
253
- for batch_idx in range(batch_size):
254
- batch_group_idx = batch_idx * self.num_beam_groups + group_index
255
- if self._done[batch_group_idx]:
256
- if self.num_beams < len(self._beam_hyps[batch_group_idx]):
257
- raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
258
- if eos_token_id is None or pad_token_id is None:
259
- raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
260
- # pad the batch
261
- next_beam_scores[batch_idx, :] = 0
262
- next_beam_tokens[batch_idx, :] = pad_token_id
263
- next_beam_indices[batch_idx, :] = 0
264
- continue
265
-
266
- # next tokens for this sentence
267
- beam_idx = 0
268
- for beam_token_rank, (next_token, next_score, next_index) in enumerate(
269
- zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
270
- ):
271
- batch_beam_idx = batch_idx * self.group_size + next_index
272
- # add to generated hypotheses if end of sentence
273
- if (eos_token_id is not None) and (next_token.item() in eos_token_id):
274
- # if beam_token does not belong to top num_beams tokens, it should not be added
275
- is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
276
- if is_beam_token_worse_than_top_num_beams:
277
- continue
278
- if beam_indices is not None:
279
- beam_index = beam_indices[batch_beam_idx]
280
- beam_index = beam_index + (batch_beam_idx,)
281
- else:
282
- beam_index = None
283
-
284
- self._beam_hyps[batch_group_idx].add(
285
- input_ids[batch_beam_idx].clone(),
286
- next_score.item(),
287
- beam_indices=beam_index,
288
- generated_len=cur_len - decoder_prompt_len,
289
- )
290
- else:
291
- # add next predicted token since it is not eos_token
292
- next_beam_scores[batch_idx, beam_idx] = next_score
293
- next_beam_tokens[batch_idx, beam_idx] = next_token
294
- next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
295
- beam_idx += 1
296
-
297
- # once the beam for next step is full, don't add more tokens to it.
298
- if beam_idx == self.group_size:
299
- break
300
-
301
- if beam_idx < self.group_size:
302
- raise ValueError(
303
- f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
304
- f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
305
- )
306
-
307
- # Check if we are done so that we can save a pad step if all(done)
308
- self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
309
- next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
310
- )
311
-
312
- return UserDict(
313
- {
314
- "next_beam_scores": next_beam_scores.view(-1),
315
- "next_beam_tokens": next_beam_tokens.view(-1),
316
- "next_beam_indices": next_beam_indices.view(-1),
317
- }
318
- )
319
-
320
- def finalize(
321
- self,
322
- input_ids: torch.LongTensor,
323
- final_beam_scores: torch.FloatTensor,
324
- final_beam_tokens: torch.LongTensor,
325
- final_beam_indices: torch.LongTensor,
326
- max_length: int,
327
- pad_token_id: Optional[Union[int, torch.Tensor]] = None,
328
- eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
329
- beam_indices: Optional[torch.LongTensor] = None,
330
- decoder_prompt_len: Optional[int] = 0,
331
- ) -> Tuple[torch.LongTensor]:
332
- batch_size = len(self._beam_hyps) // self.num_beam_groups
333
-
334
- if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
335
- if isinstance(eos_token_id, int):
336
- eos_token_id = [eos_token_id]
337
- eos_token_id = torch.tensor(eos_token_id)
338
-
339
- # finalize all open beam hypotheses and add to generated hypotheses
340
- for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
341
- if self._done[batch_group_idx]:
342
- continue
343
-
344
- # all open beam hypotheses are added to the beam hypothesis
345
- # beam hypothesis class automatically keeps the best beams
346
- for index_per_group in range(self.group_size):
347
- batch_beam_idx = batch_group_idx * self.group_size + index_per_group
348
- final_score = final_beam_scores[batch_beam_idx].item()
349
- final_tokens = input_ids[batch_beam_idx]
350
- beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
351
- generated_len = final_tokens.shape[-1] - decoder_prompt_len
352
- beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
353
-
354
- # select the best hypotheses
355
- sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
356
- best = []
357
- best_indices = []
358
- best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
359
-
360
- # retrieve best hypotheses
361
- for i in range(batch_size):
362
- beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
363
- candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
364
- sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
365
- for j in range(self.num_beam_hyps_to_keep):
366
- best_hyp_tuple = sorted_hyps.pop()
367
- best_score = best_hyp_tuple[0]
368
- best_hyp = best_hyp_tuple[1]
369
- best_index = best_hyp_tuple[2]
370
- sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
371
-
372
- # append hyp to lists
373
- best.append(best_hyp)
374
-
375
- # append indices to list
376
- best_indices.append(best_index)
377
-
378
- best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
379
-
380
- # prepare for adding eos
381
- sent_lengths_max = sent_lengths.max().item() + 1
382
- sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
383
- decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
384
-
385
- if len(best_indices) > 0 and best_indices[0] is not None:
386
- indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
387
- else:
388
- indices = None
389
-
390
- # shorter batches are padded if needed
391
- if sent_lengths.min().item() != sent_lengths.max().item():
392
- if pad_token_id is None:
393
- raise ValueError("`pad_token_id` has to be defined")
394
- decoded.fill_(pad_token_id)
395
-
396
- if indices is not None:
397
- indices.fill_(-1)
398
-
399
- # fill with hypotheses and eos_token_id if the latter fits in
400
- for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
401
- decoded[i, : sent_lengths[i]] = hypo
402
-
403
- if indices is not None:
404
- indices[i, : len(best_idx)] = torch.tensor(best_idx)
405
-
406
- if sent_lengths[i] < sent_max_len:
407
- # inserting only the first eos_token_id
408
- decoded[i, sent_lengths[i]] = eos_token_id[0]
409
-
410
- return UserDict(
411
- {
412
- "sequences": decoded,
413
- "sequence_scores": best_scores,
414
- "beam_indices": indices,
415
- }
416
- )
417
-
418
-
419
- class ConstrainedBeamSearchScorer(BeamScorer):
420
- r"""
421
- [`BeamScorer`] implementing constrained beam search decoding.
422
-
423
-
424
- Args:
425
- batch_size (`int`):
426
- Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
427
- num_beams (`int`):
428
- Number of beams for beam search.
429
- constraints (`List[Constraint]`):
430
- A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
431
- output. For more information, the documentation of [`Constraint`] should be read.
432
- device (`torch.device`):
433
- Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
434
- allocated.
435
- length_penalty (`float`, *optional*, defaults to 1.0):
436
- Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
437
- the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
438
- likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
439
- `length_penalty` < 0.0 encourages shorter sequences.
440
- do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
441
- Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
442
- `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
443
- heuristic is applied and the generation stops when it is very unlikely to find better candidates;
444
- `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
445
- beam search algorithm).
446
- num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
447
- The number of beam hypotheses that shall be returned upon calling
448
- [`~transformers.BeamSearchScorer.finalize`].
449
- num_beam_groups (`int`, *optional*, defaults to 1):
450
- Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
451
- See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
452
- max_length (`int`, *optional*):
453
- The maximum length of the sequence to be generated.
454
- """
455
-
456
- def __init__(
457
- self,
458
- batch_size: int,
459
- num_beams: int,
460
- constraints: List[Constraint],
461
- device: torch.device,
462
- length_penalty: Optional[float] = 1.0,
463
- do_early_stopping: Optional[Union[bool, str]] = False,
464
- num_beam_hyps_to_keep: Optional[int] = 1,
465
- num_beam_groups: Optional[int] = 1,
466
- max_length: Optional[int] = None,
467
- ):
468
- self.num_beams = num_beams
469
- self.device = device
470
- self.length_penalty = length_penalty
471
- self.do_early_stopping = do_early_stopping
472
- self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
473
- self.num_beam_groups = num_beam_groups
474
- self.group_size = self.num_beams // self.num_beam_groups
475
- self.constraints = constraints
476
-
477
- self._is_init = False
478
- self._beam_hyps = [
479
- BeamHypotheses(
480
- num_beams=self.num_beams,
481
- length_penalty=self.length_penalty,
482
- early_stopping=self.do_early_stopping,
483
- max_length=max_length,
484
- )
485
- for _ in range(batch_size)
486
- ]
487
- self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
488
-
489
- if not isinstance(num_beams, int) or num_beams <= 1:
490
- raise ValueError(
491
- f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
492
- " one should make use of `greedy_search` instead."
493
- )
494
-
495
- if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
496
- raise ValueError(
497
- "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
498
- f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
499
- )
500
-
501
- @property
502
- def is_done(self) -> bool:
503
- return self._done.all()
504
-
505
- def make_constraint_states(self, n):
506
- return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
507
-
508
- def check_completes_constraints(self, sequence):
509
- new_state = self.make_constraint_states(1)[0]
510
- new_state.reset(sequence)
511
- return new_state.completed
512
-
513
- def process(
514
- self,
515
- input_ids: torch.LongTensor,
516
- next_scores: torch.FloatTensor,
517
- next_tokens: torch.LongTensor,
518
- next_indices: torch.LongTensor,
519
- scores_for_all_vocab: torch.FloatTensor,
520
- pad_token_id: Optional[Union[int, torch.Tensor]] = None,
521
- eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
522
- beam_indices: Optional[torch.LongTensor] = None,
523
- decoder_prompt_len: Optional[int] = 0,
524
- ) -> Tuple[torch.Tensor]:
525
- r"""
526
- Args:
527
- input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
528
- Indices of input sequence tokens in the vocabulary.
529
-
530
- Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
531
- [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
532
-
533
- [What are input IDs?](../glossary#input-ids)
534
- next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
535
- Current scores of the top `2 * num_beams` non-finished beam hypotheses.
536
- next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
537
- `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
538
- next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
539
- Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
540
- scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
541
- The scores of all tokens in the vocabulary for each of the beam hypotheses.
542
- pad_token_id (`int`, *optional*):
543
- The id of the *padding* token.
544
- eos_token_id (`Union[int, List[int]]`, *optional*):
545
- The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
546
- beam_indices (`torch.LongTensor`, *optional*):
547
- Beam indices indicating to which beam hypothesis each token correspond.
548
- decoder_prompt_len (`int`, *optional*):
549
- The length of prompt that is included in the input to decoder.
550
- Return:
551
- `UserDict`: A dictionary composed of the fields as defined above:
552
-
553
- - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
554
- all
555
- non-finished beams.
556
-
557
- - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
558
- added
559
- to the non-finished beam_hypotheses.
560
- - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
561
- indicating to which beam the next tokens shall be added.
562
- """
563
-
564
- # add up to the length which the next_scores is calculated on (including decoder prompt)
565
- cur_len = input_ids.shape[-1] + 1
566
- batch_size = len(self._beam_hyps)
567
- if not (batch_size == (input_ids.shape[0] // self.group_size)):
568
- if self.num_beam_groups > 1:
569
- raise ValueError(
570
- f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
571
- f"size of {self.group_size} is expected by the beam scorer."
572
- )
573
- else:
574
- raise ValueError(
575
- f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
576
- f"{self.group_size} is expected by the beam scorer."
577
- )
578
-
579
- device = input_ids.device
580
-
581
- next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
582
- next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
583
- next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
584
-
585
- if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
586
- if isinstance(eos_token_id, int):
587
- eos_token_id = [eos_token_id]
588
- eos_token_id = torch.tensor(eos_token_id)
589
-
590
- for batch_idx, beam_hyp in enumerate(self._beam_hyps):
591
- if self._done[batch_idx]:
592
- if self.num_beams < len(beam_hyp):
593
- raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
594
- if eos_token_id is None or pad_token_id is None:
595
- raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
596
- # pad the batch
597
- next_beam_scores[batch_idx, :] = 0
598
- next_beam_tokens[batch_idx, :] = pad_token_id
599
- next_beam_indices[batch_idx, :] = 0
600
- continue
601
-
602
- # next tokens for this sentence.
603
- beam_idx = 0
604
- for beam_token_rank, (next_token, next_score, next_index) in enumerate(
605
- zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
606
- ):
607
- batch_beam_idx = batch_idx * self.group_size + next_index
608
- # add to generated hypotheses if end of sentence
609
- if (eos_token_id is not None) and (next_token.item() in eos_token_id):
610
- # if beam_token does not belong to top num_beams tokens, it should not be added
611
- is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
612
- if is_beam_token_worse_than_top_num_beams:
613
- continue
614
-
615
- completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
616
- if completes_constraint:
617
- if beam_indices is not None:
618
- beam_index = beam_indices[batch_beam_idx]
619
- beam_index = beam_index + (batch_beam_idx,)
620
- else:
621
- beam_index = None
622
-
623
- beam_hyp.add(
624
- input_ids[batch_beam_idx].clone(),
625
- next_score.item(),
626
- beam_indices=beam_index,
627
- generated_len=cur_len - decoder_prompt_len,
628
- )
629
- else:
630
- # add next predicted token since it is not eos_token
631
- next_beam_scores[batch_idx, beam_idx] = next_score
632
- next_beam_tokens[batch_idx, beam_idx] = next_token
633
- next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
634
- beam_idx += 1
635
-
636
- # once the beam for next step is full, don't add more tokens to it.
637
- if beam_idx == self.group_size:
638
- break
639
-
640
- new_scores, new_tokens, new_indices = self.step_sentence_constraint(
641
- batch_idx,
642
- input_ids,
643
- scores_for_all_vocab,
644
- next_beam_scores[batch_idx],
645
- next_beam_tokens[batch_idx],
646
- next_beam_indices[batch_idx],
647
- )
648
-
649
- next_beam_scores[batch_idx] = new_scores
650
- next_beam_tokens[batch_idx] = new_tokens
651
- next_beam_indices[batch_idx] = new_indices
652
-
653
- if beam_idx < self.group_size:
654
- raise ValueError(
655
- f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
656
- f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
657
- )
658
-
659
- # Check if we are done so that we can save a pad step if all(done)
660
- self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
661
- next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
662
- )
663
-
664
- return UserDict(
665
- {
666
- "next_beam_scores": next_beam_scores.view(-1),
667
- "next_beam_tokens": next_beam_tokens.view(-1),
668
- "next_beam_indices": next_beam_indices.view(-1),
669
- }
670
- )
671
-
672
- def step_sentence_constraint(
673
- self,
674
- batch_idx: int,
675
- input_ids: torch.LongTensor,
676
- vocab_scores: torch.FloatTensor,
677
- sent_beam_scores: torch.FloatTensor,
678
- sent_beam_tokens: torch.LongTensor,
679
- sent_beam_indices: torch.LongTensor,
680
- push_progress: bool = False,
681
- ):
682
- # sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
683
- # (candidate next tokens)
684
-
685
- # 1. Adding "advance_tokens"
686
- # using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
687
- # advance us in fulfilling the constraints.
688
-
689
- # 2. Selecting best candidates such that we end up with highest probable candidates
690
- # that fulfill our constraints.
691
-
692
- orig_len = sent_beam_indices.size(0)
693
- device = sent_beam_indices.device
694
-
695
- # initialize states
696
- topk_contraint_states = self.make_constraint_states(orig_len)
697
- advance_constraint_states = self.make_constraint_states(orig_len)
698
-
699
- sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
700
- this_batch_input_ids = input_ids[sidx:eidx]
701
- this_batch_token_scores = vocab_scores[sidx:eidx]
702
- full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
703
-
704
- # need to make new hypothesis that advance the constraints
705
- track_new = {
706
- "new_seqs": full_hypotheses.tolist(),
707
- "new_states": [],
708
- "new_indices": [],
709
- "new_tokens": [],
710
- "new_scores": [],
711
- }
712
- for seq_idx, pre_seq in enumerate(this_batch_input_ids):
713
- # pre_seq = ith sequence generated before this step.
714
-
715
- # input_ids -> (topk) generic beam search best model next tokens
716
- # -> (advance) constraints forcing the next token
717
- # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
718
- # hypotheses.
719
-
720
- topk_state = topk_contraint_states[seq_idx]
721
- topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
722
-
723
- advance_state = advance_constraint_states[seq_idx]
724
- advance_state.reset(pre_seq.cpu().tolist())
725
-
726
- if not advance_state.completed:
727
- advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
728
- for advance_token in advance_tokens:
729
- # since adding each `advance_token` leads to a different hypothesis, create new state instance.
730
- new_state = advance_state.copy(stateful=True)
731
- new_state.add(advance_token.cpu().tolist())
732
-
733
- advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
734
- if advance_seq not in track_new["new_seqs"]:
735
- # prevent duplicates, which are basically bound to happen in this process.
736
- track_new["new_seqs"].append(advance_seq)
737
- track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
738
- track_new["new_tokens"].append(advance_token)
739
- track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
740
- track_new["new_states"].append(new_state)
741
- elif push_progress:
742
- # Basically, `sent_beam_indices` often selects very few of the generated sequences in `input_ids` that
743
- # actually fulfill our constraints. For example, let constraints == ["loves pies"] and
744
-
745
- # pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
746
-
747
- # Without this step, if `sent_beam_indices` is something like [1,1], then
748
- # 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
749
- # 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
750
- # the else part of `if constraints_completed[seq_idx]`)
751
- # 3. it ends up simply getting removed from consideration.
752
-
753
- # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
754
- # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
755
- # search times, since completed sequences keep getting removed after all this effort for constrained
756
- # generation.
757
-
758
- # Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
759
- # appending the next likely token in the vocabulary and adding it to the list of hypotheses.
760
-
761
- new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
762
- advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
763
-
764
- advance_state = advance_constraint_states[seq_idx]
765
-
766
- advance_seq = advance_seq.cpu().tolist()
767
-
768
- advance_state.reset(advance_seq)
769
- if advance_seq not in track_new["new_seqs"]:
770
- # but still don't want to have duplicates
771
- track_new["new_seqs"].append(advance_seq)
772
- track_new["new_indices"].append(seq_idx)
773
- track_new["new_tokens"].append(new_token)
774
- track_new["new_scores"].append(new_score)
775
- track_new["new_states"].append(advance_state)
776
-
777
- if len(track_new["new_indices"]) > 0:
778
- new_indices = torch.tensor(track_new["new_indices"]).to(device)
779
- new_tokens = torch.stack(track_new["new_tokens"]).to(device)
780
- new_scores = torch.stack(track_new["new_scores"]).to(device)
781
-
782
- all_states = topk_contraint_states + track_new["new_states"]
783
- all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
784
- all_scores = torch.cat((sent_beam_scores, new_scores), -1)
785
- all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
786
-
787
- zipped = all_banks * 100 + all_scores
788
- indices = zipped.sort(descending=True).indices
789
- sorted_banks = all_banks[indices]
790
-
791
- # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
792
-
793
- counter = -1
794
- cur_bank = sorted_banks[0]
795
- increments = []
796
- for bank in sorted_banks:
797
- if bank == cur_bank:
798
- counter += 1
799
- else:
800
- counter = 0
801
- cur_bank = bank
802
- increments.append(counter)
803
- rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
804
-
805
- indices = indices[rearrangers][:orig_len]
806
-
807
- sent_beam_scores = all_scores[indices]
808
- sent_beam_tokens = all_tokens[indices]
809
- sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
810
-
811
- return sent_beam_scores, sent_beam_tokens, sent_beam_indices
812
-
813
- def finalize(
814
- self,
815
- input_ids: torch.LongTensor,
816
- final_beam_scores: torch.FloatTensor,
817
- final_beam_tokens: torch.LongTensor,
818
- final_beam_indices: torch.LongTensor,
819
- max_length: int,
820
- pad_token_id: Optional[Union[int, torch.Tensor]] = None,
821
- eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
822
- beam_indices: Optional[torch.LongTensor] = None,
823
- decoder_prompt_len: Optional[int] = 0,
824
- ) -> Tuple[torch.LongTensor]:
825
- batch_size = len(self._beam_hyps)
826
-
827
- if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
828
- if isinstance(eos_token_id, int):
829
- eos_token_id = [eos_token_id]
830
- eos_token_id = torch.tensor(eos_token_id)
831
-
832
- # finalize all open beam hypotheses and add to generated hypotheses
833
- for batch_idx, beam_hyp in enumerate(self._beam_hyps):
834
- if self._done[batch_idx]:
835
- continue
836
-
837
- # all open beam hypotheses are added to the beam hypothesis
838
- # beam hypothesis class automatically keeps the best beams
839
-
840
- ids_collect = []
841
- for beam_id in range(self.num_beams):
842
- batch_beam_idx = batch_idx * self.num_beams + beam_id
843
- final_score = final_beam_scores[batch_beam_idx].item()
844
- final_tokens = input_ids[batch_beam_idx]
845
-
846
- completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
847
- if completes_constraint:
848
- beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
849
- generated_len = final_tokens.shape[-1] - decoder_prompt_len
850
- beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
851
- ids_collect.append(beam_id)
852
-
853
- # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
854
- # generation. In these cases we simply return the highest scoring outputs.
855
- if len(ids_collect) < self.num_beam_hyps_to_keep:
856
- for beam_id in range(self.num_beams):
857
- if beam_id not in ids_collect:
858
- batch_beam_idx = batch_idx * self.num_beams + beam_id
859
- final_score = final_beam_scores[batch_beam_idx].item()
860
- final_tokens = input_ids[batch_beam_idx]
861
- generated_len = final_tokens.shape[-1] - decoder_prompt_len
862
- beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
863
- if len(ids_collect) >= self.num_beam_hyps_to_keep:
864
- break
865
-
866
- # select the best hypotheses
867
- sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
868
- best = []
869
- best_indices = []
870
- best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
871
-
872
- # retrieve best hypotheses
873
- for i, beam_hyp in enumerate(self._beam_hyps):
874
- sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
875
- for j in range(self.num_beam_hyps_to_keep):
876
- best_hyp_tuple = sorted_hyps.pop()
877
- best_score = best_hyp_tuple[0]
878
- best_hyp = best_hyp_tuple[1]
879
- best_index = best_hyp_tuple[2]
880
- sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
881
-
882
- # append to lists
883
- best.append(best_hyp)
884
-
885
- # append indices to list
886
- best_indices.append(best_index)
887
-
888
- best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
889
-
890
- # prepare for adding eos
891
- sent_lengths_max = sent_lengths.max().item() + 1
892
-
893
- sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
894
- decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
895
-
896
- if len(best_indices) > 0 and best_indices[0] is not None:
897
- indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
898
- else:
899
- indices = None
900
-
901
- # shorter batches are padded if needed
902
- if sent_lengths.min().item() != sent_lengths.max().item():
903
- if pad_token_id is None:
904
- raise ValueError("`pad_token_id` has to be defined")
905
- decoded.fill_(pad_token_id)
906
-
907
- if indices is not None:
908
- indices.fill_(-1)
909
-
910
- # fill with hypotheses and eos_token_id if the latter fits in
911
- for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
912
- decoded[i, : sent_lengths[i]] = hypo
913
-
914
- if indices is not None:
915
- indices[i, : len(best_idx)] = torch.tensor(best_idx)
916
-
917
- if sent_lengths[i] < sent_max_len:
918
- # inserting only the first eos_token_id
919
- decoded[i, sent_lengths[i]] = eos_token_id[0]
920
-
921
- return UserDict(
922
- {
923
- "sequences": decoded,
924
- "sequence_scores": best_scores,
925
- "beam_indices": indices,
926
- }
927
- )
928
-
929
-
930
- class BeamHypotheses:
931
- def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
932
- """
933
- Initialize n-best list of hypotheses.
934
- """
935
- self.length_penalty = length_penalty
936
- self.early_stopping = early_stopping
937
- self.max_length = max_length
938
- self.num_beams = num_beams
939
- self.beams = []
940
- self.worst_score = 1e9
941
-
942
- if not isinstance(self.early_stopping, bool) and self.max_length is None:
943
- raise ValueError(
944
- "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
945
- " BeamScorer class instance at initialization time."
946
- )
947
-
948
- def __len__(self):
949
- """
950
- Number of hypotheses in the list.
951
- """
952
- return len(self.beams)
953
-
954
- def add(
955
- self,
956
- hyp: torch.LongTensor,
957
- sum_logprobs: float,
958
- beam_indices: Optional[torch.LongTensor] = None,
959
- generated_len: Optional[int] = None,
960
- ):
961
- """
962
- Add a new hypothesis to the list.
963
- """
964
- if generated_len is not None:
965
- score = sum_logprobs / (generated_len**self.length_penalty)
966
- # This 'else' case exists for retrocompatibility
967
- else:
968
- score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
969
-
970
- if len(self) < self.num_beams or score > self.worst_score:
971
- self.beams.append((score, hyp, beam_indices))
972
- if len(self) > self.num_beams:
973
- sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
974
- del self.beams[sorted_next_scores[0][1]]
975
- self.worst_score = sorted_next_scores[1][0]
976
- else:
977
- self.worst_score = min(score, self.worst_score)
978
-
979
- def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
980
- """
981
- If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
982
- one in the heap, then we are done with this sentence.
983
- """
984
-
985
- if len(self) < self.num_beams:
986
- return False
987
-
988
- # `True`: stop as soon as at least `num_beams` hypotheses are finished
989
- if self.early_stopping is True:
990
- return True
991
- # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
992
- # when `length_penalty` is positive. See the discussion below for more details.
993
- # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
994
- elif self.early_stopping is False:
995
- highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
996
- ret = self.worst_score >= highest_attainable_score
997
- return ret
998
- # `"never"`: compute the best possible score, depending on the signal of `length_penalty`
999
- else:
1000
- # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
1001
- # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
1002
- # its max this way
1003
- if self.length_penalty > 0.0:
1004
- if self.max_length <= decoder_prompt_len:
1005
- raise ValueError("max_length is not larger than decoder prompt length")
1006
- highest_attainable_score = (
1007
- best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
1008
- )
1009
- # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
1010
- else:
1011
- highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
1012
- ret = self.worst_score >= highest_attainable_score
1013
- return ret
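
Note on the removed scorer above: `BeamHypotheses.add` scores a finished hypothesis as its summed log-probability divided by `generated_len ** length_penalty`, and the `early_stopping=False` heuristic in `is_done` stops once no open beam could still beat the worst kept score. A minimal, self-contained sketch of that rule, with simplified names that are not taken from the removed file:

    # Hypothetical illustration of the length-penalty scoring used by beam search.
    def hypothesis_score(sum_logprobs: float, generated_len: int, length_penalty: float = 1.0) -> float:
        # Log-probabilities are negative, so length_penalty > 0.0 favors longer sequences.
        return sum_logprobs / (generated_len ** length_penalty)

    def could_still_improve(best_sum_logprobs: float, cur_len: int, worst_kept_score: float,
                            length_penalty: float = 1.0) -> bool:
        # early_stopping=False heuristic: best score attainable at the current length
        # versus the worst hypothesis already kept.
        return best_sum_logprobs / (cur_len ** length_penalty) > worst_kept_score

    print(hypothesis_score(-12.0, 10))                            # -1.2
    print(could_still_improve(-9.0, 10, worst_kept_score=-1.2))   # True: -0.9 > -1.2

The constrained scorer's `step_sentence_constraint` sorts candidates bank-major via `all_banks * 100 + all_scores`: a candidate's "bank" counts how many constraint steps it has fulfilled, and log-prob scores are non-positive, so scaling banks by 100 keeps bank order dominant as long as |score| < 100. A small illustrative check of that ordering (not part of the diff):

    import torch

    banks = torch.tensor([2, 0, 2, 1])
    scores = torch.tensor([-3.0, -0.5, -1.0, -2.0])
    order = (banks * 100 + scores).sort(descending=True).indices
    print(order)  # tensor([2, 0, 3, 1]): bank 2 first (better score ahead), then bank 1, then bank 0
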
indextts/gpt/transformers_generation_utils.py DELETED
The diff for this file is too large to render. See raw diff
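
The removed transformers_generation_utils.py vendored the GenerationMixin that drives the beam-search scorers above. For orientation only, the equivalent decoding path in the upstream transformers package is reached through `generate`; a minimal sketch assuming the public openai-community/gpt2 checkpoint (nothing here is taken from the removed file):

    from transformers import AutoTokenizer, GPT2LMHeadModel

    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

    inputs = tokenizer("Beam search keeps the", return_tensors="pt")
    # num_beams, length_penalty and early_stopping correspond to the BeamSearchScorer
    # arguments documented in the removed transformers_beam_search.py.
    outputs = model.generate(
        **inputs,
        max_new_tokens=20,
        num_beams=4,
        num_return_sequences=2,
        length_penalty=1.0,
        early_stopping=False,
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
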
 
indextts/gpt/transformers_gpt2.py DELETED
@@ -1,1878 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """PyTorch OpenAI GPT-2 model."""
17
-
18
- import math
19
- import os
20
- import warnings
21
- from dataclasses import dataclass
22
- from typing import Optional, Tuple, Union
23
-
24
- import torch
25
- import torch.utils.checkpoint
26
- from packaging import version
27
- from torch import nn
28
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
-
30
- from transformers.activations import ACT2FN
31
- import transformers
32
-
33
- from indextts.gpt.transformers_generation_utils import GenerationMixin
34
- from indextts.gpt.transformers_modeling_utils import PreTrainedModel
35
- from transformers.modeling_utils import SequenceSummary
36
-
37
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
38
- from transformers.modeling_outputs import (
39
- BaseModelOutputWithPastAndCrossAttentions,
40
- CausalLMOutputWithCrossAttentions,
41
- QuestionAnsweringModelOutput,
42
- SequenceClassifierOutputWithPast,
43
- TokenClassifierOutput,
44
- )
45
- # from transformers.modeling_utils import PreTrainedModel, SequenceSummary
46
-
47
- from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
48
- from transformers.utils import (
49
- ModelOutput,
50
- add_code_sample_docstrings,
51
- add_start_docstrings,
52
- add_start_docstrings_to_model_forward,
53
- get_torch_version,
54
- is_flash_attn_2_available,
55
- is_flash_attn_greater_or_equal_2_10,
56
- logging,
57
- replace_return_docstrings,
58
- )
59
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
60
- from transformers.models.gpt2.configuration_gpt2 import GPT2Config
61
-
62
-
63
- if is_flash_attn_2_available():
64
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
65
-
66
-
67
- logger = logging.get_logger(__name__)
68
-
69
- _CHECKPOINT_FOR_DOC = "openai-community/gpt2"
70
- _CONFIG_FOR_DOC = "GPT2Config"
71
-
72
-
73
- def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
74
- """Load tf checkpoints in a pytorch model"""
75
- try:
76
- import re
77
-
78
- import tensorflow as tf
79
- except ImportError:
80
- logger.error(
81
- "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
82
- "https://www.tensorflow.org/install/ for installation instructions."
83
- )
84
- raise
85
- tf_path = os.path.abspath(gpt2_checkpoint_path)
86
- logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
87
- # Load weights from TF model
88
- init_vars = tf.train.list_variables(tf_path)
89
- names = []
90
- arrays = []
91
- for name, shape in init_vars:
92
- logger.info(f"Loading TF weight {name} with shape {shape}")
93
- array = tf.train.load_variable(tf_path, name)
94
- names.append(name)
95
- arrays.append(array.squeeze())
96
-
97
- for name, array in zip(names, arrays):
98
- name = name[6:] # skip "model/"
99
- name = name.split("/")
100
- pointer = model
101
- for m_name in name:
102
- if re.fullmatch(r"[A-Za-z]+\d+", m_name):
103
- scope_names = re.split(r"(\d+)", m_name)
104
- else:
105
- scope_names = [m_name]
106
- if scope_names[0] == "w" or scope_names[0] == "g":
107
- pointer = getattr(pointer, "weight")
108
- elif scope_names[0] == "b":
109
- pointer = getattr(pointer, "bias")
110
- elif scope_names[0] == "wpe" or scope_names[0] == "wte":
111
- pointer = getattr(pointer, scope_names[0])
112
- pointer = getattr(pointer, "weight")
113
- else:
114
- pointer = getattr(pointer, scope_names[0])
115
- if len(scope_names) >= 2:
116
- num = int(scope_names[1])
117
- pointer = pointer[num]
118
- try:
119
- if pointer.shape != array.shape:
120
- raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
121
- except ValueError as e:
122
- e.args += (pointer.shape, array.shape)
123
- raise
124
- logger.info(f"Initialize PyTorch weight {name}")
125
- pointer.data = torch.from_numpy(array)
126
- return model
127
-
128
-
129
- class GPT2Attention(nn.Module):
130
- def __init__(self, config, is_cross_attention=False, layer_idx=None):
131
- super().__init__()
132
- self.config = config
133
- max_positions = config.max_position_embeddings
134
- self.register_buffer(
135
- "bias",
136
- torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
137
- 1, 1, max_positions, max_positions
138
- ),
139
- persistent=False,
140
- )
141
- self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
142
-
143
- self.embed_dim = config.hidden_size
144
- self.num_heads = config.num_attention_heads
145
- self.head_dim = self.embed_dim // self.num_heads
146
- self.split_size = self.embed_dim
147
- if self.head_dim * self.num_heads != self.embed_dim:
148
- raise ValueError(
149
- f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
150
- f" {self.num_heads})."
151
- )
152
-
153
- self.scale_attn_weights = config.scale_attn_weights
154
- self.is_cross_attention = is_cross_attention
155
-
156
- # Layer-wise attention scaling, reordering, and upcasting
157
- self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
158
- self.layer_idx = layer_idx
159
- self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
160
-
161
- if self.is_cross_attention:
162
- self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
163
- self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
164
- else:
165
- self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
166
- self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
167
-
168
- self.attn_dropout = nn.Dropout(config.attn_pdrop)
169
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
170
- self.is_causal = True
171
-
172
- self.pruned_heads = set()
173
-
174
- def prune_heads(self, heads):
175
- if len(heads) == 0:
176
- return
177
- heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
178
- index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
179
-
180
- # Prune conv1d layers
181
- self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
182
- self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
183
-
184
- # Update hyper params
185
- self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
186
- self.num_heads = self.num_heads - len(heads)
187
- self.pruned_heads = self.pruned_heads.union(heads)
188
-
189
- def _attn(self, query, key, value, attention_mask=None, head_mask=None):
190
- attn_weights = torch.matmul(query, key.transpose(-1, -2))
191
-
192
- if self.scale_attn_weights:
193
- attn_weights = attn_weights / torch.full(
194
- [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
195
- )
196
-
197
- # Layer-wise attention scaling
198
- if self.scale_attn_by_inverse_layer_idx:
199
- attn_weights = attn_weights / float(self.layer_idx + 1)
200
-
201
- if not self.is_cross_attention:
202
- # if only "normal" attention layer implements causal mask
203
- query_length, key_length = query.size(-2), key.size(-2)
204
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
205
- mask_value = torch.finfo(attn_weights.dtype).min
206
- # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
207
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
208
- mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
209
- attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
210
-
211
- if attention_mask is not None:
212
- # Apply the attention mask
213
- attn_weights = attn_weights + attention_mask
214
-
215
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
216
-
217
- # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
218
- attn_weights = attn_weights.type(value.dtype)
219
- attn_weights = self.attn_dropout(attn_weights)
220
-
221
- # Mask heads if we want to
222
- if head_mask is not None:
223
- attn_weights = attn_weights * head_mask
224
-
225
- attn_output = torch.matmul(attn_weights, value)
226
-
227
- return attn_output, attn_weights
228
-
229
- def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
230
- # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
231
- bsz, num_heads, q_seq_len, dk = query.size()
232
- _, _, k_seq_len, _ = key.size()
233
-
234
- # Preallocate attn_weights for `baddbmm`
235
- attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
236
-
237
- # Compute Scale Factor
238
- scale_factor = 1.0
239
- if self.scale_attn_weights:
240
- scale_factor /= float(value.size(-1)) ** 0.5
241
-
242
- if self.scale_attn_by_inverse_layer_idx:
243
- scale_factor /= float(self.layer_idx + 1)
244
-
245
- # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
246
- with torch.amp.autocast(query.device.type, enabled=False):
247
- q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
248
- attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
249
- attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
250
-
251
- if not self.is_cross_attention:
252
- # if only "normal" attention layer implements causal mask
253
- query_length, key_length = query.size(-2), key.size(-2)
254
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
255
- mask_value = torch.finfo(attn_weights.dtype).min
256
- # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
257
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
258
- mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
259
- attn_weights = torch.where(causal_mask, attn_weights, mask_value)
260
-
261
- if attention_mask is not None:
262
- # Apply the attention mask
263
- attn_weights = attn_weights + attention_mask
264
-
265
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
266
-
267
- # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
268
- if attn_weights.dtype != torch.float32:
269
- raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
270
- attn_weights = attn_weights.type(value.dtype)
271
- attn_weights = self.attn_dropout(attn_weights)
272
-
273
- # Mask heads if we want to
274
- if head_mask is not None:
275
- attn_weights = attn_weights * head_mask
276
-
277
- attn_output = torch.matmul(attn_weights, value)
278
-
279
- return attn_output, attn_weights
280
-
281
- def _split_heads(self, tensor, num_heads, attn_head_size):
282
- """
283
- Splits hidden_size dim into attn_head_size and num_heads
284
- """
285
- new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
286
- tensor = tensor.view(new_shape)
287
- return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
288
-
289
- def _merge_heads(self, tensor, num_heads, attn_head_size):
290
- """
291
- Merges attn_head_size dim and num_attn_heads dim into hidden_size
292
- """
293
- tensor = tensor.permute(0, 2, 1, 3).contiguous()
294
- new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
295
- return tensor.view(new_shape)
296
-
297
- def forward(
298
- self,
299
- hidden_states: Optional[Tuple[torch.FloatTensor]],
300
- layer_past: Optional[Tuple[torch.Tensor]] = None,
301
- attention_mask: Optional[torch.FloatTensor] = None,
302
- head_mask: Optional[torch.FloatTensor] = None,
303
- encoder_hidden_states: Optional[torch.Tensor] = None,
304
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
305
- use_cache: Optional[bool] = False,
306
- output_attentions: Optional[bool] = False,
307
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
308
- if encoder_hidden_states is not None:
309
- if not hasattr(self, "q_attn"):
310
- raise ValueError(
311
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
312
- "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
313
- )
314
-
315
- query = self.q_attn(hidden_states)
316
- key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
317
- attention_mask = encoder_attention_mask
318
- else:
319
- query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
320
-
321
- query = self._split_heads(query, self.num_heads, self.head_dim)
322
- key = self._split_heads(key, self.num_heads, self.head_dim)
323
- value = self._split_heads(value, self.num_heads, self.head_dim)
324
-
325
- if layer_past is not None:
326
- past_key, past_value = layer_past
327
- key = torch.cat((past_key, key), dim=-2)
328
- value = torch.cat((past_value, value), dim=-2)
329
-
330
- if use_cache is True:
331
- present = (key, value)
332
- else:
333
- present = None
334
-
335
- if self.reorder_and_upcast_attn:
336
- attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
337
- else:
338
- attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
339
-
340
- attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
341
- attn_output = self.c_proj(attn_output)
342
- attn_output = self.resid_dropout(attn_output)
343
-
344
- outputs = (attn_output, present)
345
- if output_attentions:
346
- outputs += (attn_weights,)
347
-
348
- return outputs # a, present, (attentions)
349
-
350
-
351
- class GPT2FlashAttention2(GPT2Attention):
352
- """
353
- GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stay
354
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
355
- flash attention and deal with padding tokens in case the input contains any of them.
356
- """
357
-
358
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
359
- def __init__(self, *args, **kwargs):
360
- super().__init__(*args, **kwargs)
361
-
362
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
363
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
364
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
365
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
366
-
367
- def forward(
368
- self,
369
- hidden_states: Optional[Tuple[torch.FloatTensor]],
370
- layer_past: Optional[Tuple[torch.Tensor]] = None,
371
- attention_mask: Optional[torch.FloatTensor] = None,
372
- head_mask: Optional[torch.FloatTensor] = None,
373
- encoder_hidden_states: Optional[torch.Tensor] = None,
374
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
375
- use_cache: Optional[bool] = False,
376
- output_attentions: Optional[bool] = False,
377
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
378
- bsz, _, _ = hidden_states.size()
379
- if encoder_hidden_states is not None:
380
- if not hasattr(self, "q_attn"):
381
- raise ValueError(
382
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
383
- "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
384
- )
385
-
386
- query = self.q_attn(hidden_states)
387
- key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
388
- attention_mask = encoder_attention_mask
389
- else:
390
- query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
391
-
392
- query = self._split_heads(query, self.num_heads, self.head_dim)
393
- key = self._split_heads(key, self.num_heads, self.head_dim)
394
- value = self._split_heads(value, self.num_heads, self.head_dim)
395
-
396
- if layer_past is not None:
397
- past_key = layer_past[0]
398
- past_value = layer_past[1]
399
- key = torch.cat((past_key, key), dim=-2)
400
- value = torch.cat((past_value, value), dim=-2)
401
-
402
- present = None
403
- if use_cache is True:
404
- present = (key, value)
405
-
406
- query_length = query.shape[2]
407
- tgt_len = key.shape[2]
408
-
409
- # Flash attention requires the input to have the shape
410
- # batch_size x seq_length x head_dim x hidden_dim
411
- query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
412
- key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
413
- value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
414
-
415
- attn_dropout = self.attn_dropout.p if self.training else 0.0
416
-
417
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
418
- # therefore the input hidden states get silently cast to float32. Hence, we need to
419
- # cast them back to the correct dtype just to be sure everything works as expected.
420
- # This might slow down training & inference, so it is recommended not to cast the LayerNorms
421
- # in fp32. (LlamaRMSNorm handles it correctly)
422
-
423
- if query.dtype == torch.float32:
424
- if torch.is_autocast_enabled():
425
- target_dtype = torch.get_autocast_gpu_dtype()
426
- # Handle the case where the model is quantized
427
- elif hasattr(self.config, "_pre_quantization_dtype"):
428
- target_dtype = self.config._pre_quantization_dtype
429
- else:
430
- target_dtype = self.c_proj.weight.dtype
431
-
432
- logger.warning_once(
433
- f"The input hidden states seems to be silently casted in float32, this might be related to"
434
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
435
- f" {target_dtype}."
436
- )
437
-
438
- query = query.to(target_dtype)
439
- key = key.to(target_dtype)
440
- value = value.to(target_dtype)
441
-
442
- attn_output = _flash_attention_forward(
443
- query,
444
- key,
445
- value,
446
- attention_mask,
447
- query_length,
448
- dropout=attn_dropout,
449
- is_causal=self.is_causal,
450
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
451
- )
452
-
453
- attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
454
- attn_output = self.c_proj(attn_weights_reshaped)
455
- attn_output = self.resid_dropout(attn_output)
456
-
457
- outputs = (attn_output, present)
458
- if output_attentions:
459
- outputs += (attn_weights_reshaped,)
460
-
461
- return outputs
462
-
463
-
464
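For reference, a minimal standalone sketch (plain PyTorch, made-up dimensions) of the layout change performed in the forward above: the eager/SDPA paths keep tensors as `(batch, heads, seq, head_dim)`, while the flash-attention kernel is fed `(batch, seq, heads, head_dim)`.

```python
import torch

# Hypothetical dimensions, for illustration only.
bsz, num_heads, seq_len, head_dim = 2, 12, 5, 64

# Layout used by the eager/SDPA paths: (batch, heads, seq, head_dim).
q = torch.randn(bsz, num_heads, seq_len, head_dim)

# Flash attention is fed (batch, seq, heads, head_dim); same transpose as in the forward above.
q_flash = q.transpose(1, 2)
print(q_flash.shape)  # torch.Size([2, 5, 12, 64])
```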
- class GPT2SdpaAttention(GPT2Attention):
465
- """
466
- GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
467
- `GPT2Attention`, as the weights of the module stay untouched. The only changes are on the forward pass
468
- to adapt to the SDPA API.
469
- """
470
-
471
- def __init__(self, *args, **kwargs):
472
- super().__init__(*args, **kwargs)
473
-
474
- # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
475
- # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
476
- # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
477
- # Reference: https://github.com/pytorch/pytorch/issues/112577
478
- self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
479
-
480
- def forward(
481
- self,
482
- hidden_states: Optional[Tuple[torch.FloatTensor]],
483
- layer_past: Optional[Tuple[torch.Tensor]] = None,
484
- attention_mask: Optional[torch.FloatTensor] = None,
485
- head_mask: Optional[torch.FloatTensor] = None,
486
- encoder_hidden_states: Optional[torch.Tensor] = None,
487
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
488
- use_cache: Optional[bool] = False,
489
- output_attentions: Optional[bool] = False,
490
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
491
- if output_attentions or head_mask is not None:
492
- logger.warning_once(
493
- "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
494
- "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
495
- "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
496
- 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
497
- )
498
- return super().forward(
499
- hidden_states=hidden_states,
500
- layer_past=layer_past,
501
- attention_mask=attention_mask,
502
- head_mask=head_mask,
503
- encoder_hidden_states=encoder_hidden_states,
504
- encoder_attention_mask=encoder_attention_mask,
505
- use_cache=use_cache,
506
- output_attentions=output_attentions,
507
- )
508
-
509
- bsz, q_len, _ = hidden_states.size()
510
-
511
- # Initial attention projections
512
- is_cross_attention = encoder_hidden_states is not None
513
- if is_cross_attention:
514
- if not hasattr(self, "q_attn"):
515
- raise ValueError(
516
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
517
- "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
518
- )
519
-
520
- query = self.q_attn(hidden_states)
521
- key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
522
- attention_mask = encoder_attention_mask
523
- else:
524
- query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
525
-
526
- query = self._split_heads(query, self.num_heads, self.head_dim)
527
- key = self._split_heads(key, self.num_heads, self.head_dim)
528
- value = self._split_heads(value, self.num_heads, self.head_dim)
529
-
530
- # Optional kv caching
531
- if layer_past is not None:
532
- past_key = layer_past[0]
533
- past_value = layer_past[1]
534
- key = torch.cat((past_key, key), dim=-2)
535
- value = torch.cat((past_value, value), dim=-2)
536
-
537
- present = None
538
- if use_cache is True:
539
- present = (key, value)
540
-
541
- # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
542
- if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
543
- query = query.contiguous()
544
- key = key.contiguous()
545
- value = value.contiguous()
546
-
547
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
548
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
549
- is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
550
-
551
- attn_output = torch.nn.functional.scaled_dot_product_attention(
552
- query,
553
- key,
554
- value,
555
- attn_mask=attention_mask,
556
- dropout_p=self.attn_dropout.p if self.training else 0.0,
557
- is_causal=is_causal,
558
- )
559
-
560
- # Reshape outputs
561
- attn_output = attn_output.transpose(1, 2).contiguous()
562
- attn_output = attn_output.view(bsz, q_len, self.embed_dim)
563
-
564
- # Final projection
565
- attn_output = self.c_proj(attn_output)
566
- attn_output = self.resid_dropout(attn_output)
567
-
568
- return attn_output, present, None
569
-
570
-
571
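As an illustrative aside, a small self-contained sketch of the `scaled_dot_product_attention` call pattern used above, with toy tensors that are not taken from the model: `is_causal=True` lets SDPA build the causal mask internally, while an explicit `attn_mask` is passed through unchanged.

```python
import torch
import torch.nn.functional as F

# Toy tensors in (batch, heads, seq, head_dim) layout; values are illustrative only.
q = torch.randn(1, 4, 6, 16)
k = torch.randn(1, 4, 6, 16)
v = torch.randn(1, 4, 6, 16)

# No explicit mask and seq > 1: the module above sets is_causal=True and SDPA
# builds the causal mask internally.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# With an explicit additive mask, is_causal stays False and the mask is passed through.
mask = torch.zeros(1, 1, 6, 6)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(out_causal.shape, out_masked.shape)  # both torch.Size([1, 4, 6, 16])
```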
- class GPT2MLP(nn.Module):
572
- def __init__(self, intermediate_size, config):
573
- super().__init__()
574
- embed_dim = config.hidden_size
575
- self.c_fc = Conv1D(intermediate_size, embed_dim)
576
- self.c_proj = Conv1D(embed_dim, intermediate_size)
577
- self.act = ACT2FN[config.activation_function]
578
- self.dropout = nn.Dropout(config.resid_pdrop)
579
-
580
- def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
581
- hidden_states = self.c_fc(hidden_states)
582
- hidden_states = self.act(hidden_states)
583
- hidden_states = self.c_proj(hidden_states)
584
- hidden_states = self.dropout(hidden_states)
585
- return hidden_states
586
-
587
-
588
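A hedged, standalone sketch of the same fc → activation → proj → dropout pipeline using plain `nn.Linear` (the `Conv1D` above is functionally a linear layer with a transposed weight layout); the sizes are GPT-2-small-style defaults assumed for illustration, not read from any config.

```python
import torch
import torch.nn as nn

# Plain nn.Linear version of the fc -> activation -> proj -> dropout pipeline.
class ToyGPT2MLP(nn.Module):
    def __init__(self, hidden_size: int = 768, intermediate_size: int = 3072, p: float = 0.1):
        super().__init__()
        self.c_fc = nn.Linear(hidden_size, intermediate_size)
        self.c_proj = nn.Linear(intermediate_size, hidden_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(p)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(self.act(self.c_fc(hidden_states))))

print(ToyGPT2MLP()(torch.randn(2, 10, 768)).shape)  # torch.Size([2, 10, 768])
```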
- GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
589
-
590
-
591
- class GPT2Block(nn.Module):
592
- def __init__(self, config, layer_idx=None):
593
- super().__init__()
594
- hidden_size = config.hidden_size
595
- inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
596
- attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]
597
-
598
- self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
599
- self.attn = attention_class(config=config, layer_idx=layer_idx)
600
- self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
601
-
602
- if config.add_cross_attention:
603
- self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
604
- self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
605
-
606
- self.mlp = GPT2MLP(inner_dim, config)
607
-
608
- def forward(
609
- self,
610
- hidden_states: Optional[Tuple[torch.FloatTensor]],
611
- layer_past: Optional[Tuple[torch.Tensor]] = None,
612
- attention_mask: Optional[torch.FloatTensor] = None,
613
- head_mask: Optional[torch.FloatTensor] = None,
614
- encoder_hidden_states: Optional[torch.Tensor] = None,
615
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
616
- use_cache: Optional[bool] = False,
617
- output_attentions: Optional[bool] = False,
618
- ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
619
- residual = hidden_states
620
- hidden_states = self.ln_1(hidden_states)
621
- attn_outputs = self.attn(
622
- hidden_states,
623
- layer_past=layer_past,
624
- attention_mask=attention_mask,
625
- head_mask=head_mask,
626
- use_cache=use_cache,
627
- output_attentions=output_attentions,
628
- )
629
- attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
630
- outputs = attn_outputs[1:]
631
- # residual connection
632
- hidden_states = attn_output + residual
633
-
634
- if encoder_hidden_states is not None:
635
- # add one self-attention block for cross-attention
636
- if not hasattr(self, "crossattention"):
637
- raise ValueError(
638
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
639
- "cross-attention layers by setting `config.add_cross_attention=True`"
640
- )
641
- residual = hidden_states
642
- hidden_states = self.ln_cross_attn(hidden_states)
643
- cross_attn_outputs = self.crossattention(
644
- hidden_states,
645
- attention_mask=attention_mask,
646
- head_mask=head_mask,
647
- encoder_hidden_states=encoder_hidden_states,
648
- encoder_attention_mask=encoder_attention_mask,
649
- output_attentions=output_attentions,
650
- )
651
- attn_output = cross_attn_outputs[0]
652
- # residual connection
653
- hidden_states = residual + attn_output
654
- outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
655
-
656
- residual = hidden_states
657
- hidden_states = self.ln_2(hidden_states)
658
- feed_forward_hidden_states = self.mlp(hidden_states)
659
- # residual connection
660
- hidden_states = residual + feed_forward_hidden_states
661
-
662
- if use_cache:
663
- outputs = (hidden_states,) + outputs
664
- else:
665
- outputs = (hidden_states,) + outputs[1:]
666
-
667
- return outputs # hidden_states, present, (attentions, cross_attentions)
668
-
669
-
670
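A minimal sketch of the pre-LayerNorm residual pattern the block above implements, with the attention and MLP sub-modules replaced by single linear stand-ins purely for illustration.

```python
import torch
import torch.nn as nn

# Pre-LayerNorm residual skeleton of the block above.
hidden_size = 768
ln_1, ln_2 = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
attn = nn.Linear(hidden_size, hidden_size)   # stand-in for self-attention
mlp = nn.Linear(hidden_size, hidden_size)    # stand-in for GPT2MLP

x = torch.randn(2, 10, hidden_size)
x = x + attn(ln_1(x))   # residual connection around (LN -> attention)
x = x + mlp(ln_2(x))    # residual connection around (LN -> MLP)
print(x.shape)  # torch.Size([2, 10, 768])
```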
- class GPT2PreTrainedModel(PreTrainedModel):
671
- """
672
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
673
- models.
674
- """
675
-
676
- config_class = GPT2Config
677
- load_tf_weights = load_tf_weights_in_gpt2
678
- base_model_prefix = "transformer"
679
- is_parallelizable = True
680
- supports_gradient_checkpointing = True
681
- _no_split_modules = ["GPT2Block"]
682
- _skip_keys_device_placement = "past_key_values"
683
- _supports_flash_attn_2 = True
684
- _supports_sdpa = True
685
-
686
- def __init__(self, *inputs, **kwargs):
687
- super().__init__(*inputs, **kwargs)
688
-
689
- def _init_weights(self, module):
690
- """Initialize the weights."""
691
- if isinstance(module, (nn.Linear, Conv1D)):
692
- # Slightly different from the TF version which uses truncated_normal for initialization
693
- # cf https://github.com/pytorch/pytorch/pull/5617
694
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
695
- if module.bias is not None:
696
- module.bias.data.zero_()
697
- elif isinstance(module, nn.Embedding):
698
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
699
- if module.padding_idx is not None:
700
- module.weight.data[module.padding_idx].zero_()
701
- elif isinstance(module, nn.LayerNorm):
702
- module.bias.data.zero_()
703
- module.weight.data.fill_(1.0)
704
-
705
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
706
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
707
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
708
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
709
- #
710
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
711
- for name, p in module.named_parameters():
712
- if name == "c_proj.weight":
713
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
714
- p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
715
-
716
-
717
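A small sketch of the residual-path scaling described in the comment above: `c_proj` weights are drawn with `std = initializer_range / sqrt(2 * n_layer)`. The 0.02 and 12 used here are the usual GPT-2-small values, assumed only for illustration.

```python
import math
import torch.nn as nn

# Scaled initialization for a residual-path projection, per the comment above.
initializer_range, n_layer = 0.02, 12
proj = nn.Linear(768, 768)
proj.weight.data.normal_(mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
print(round(proj.weight.std().item(), 4))  # roughly 0.0041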
- @dataclass
718
- class GPT2DoubleHeadsModelOutput(ModelOutput):
719
- """
720
- Base class for outputs of models predicting if two sentences are consecutive or not.
721
-
722
- Args:
723
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
724
- Language modeling loss.
725
- mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
726
- Multiple choice classification loss.
727
- logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
728
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
729
- mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
730
- Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
731
- past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
732
- Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
733
- sequence_length, embed_size_per_head)`).
734
-
735
- Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
736
- `past_key_values` input) to speed up sequential decoding.
737
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
738
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
739
- shape `(batch_size, sequence_length, hidden_size)`.
740
-
741
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
742
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
743
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
744
- sequence_length)`.
745
-
746
- GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
747
- self-attention heads.
748
- """
749
-
750
- loss: Optional[torch.FloatTensor] = None
751
- mc_loss: Optional[torch.FloatTensor] = None
752
- logits: torch.FloatTensor = None
753
- mc_logits: torch.FloatTensor = None
754
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
755
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
756
- attentions: Optional[Tuple[torch.FloatTensor]] = None
757
-
758
-
759
- GPT2_START_DOCSTRING = r"""
760
-
761
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
762
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
763
- etc.)
764
-
765
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
766
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
767
- and behavior.
768
-
769
- Parameters:
770
- config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
771
- Initializing with a config file does not load the weights associated with the model, only the
772
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
773
- """
774
-
775
- GPT2_INPUTS_DOCSTRING = r"""
776
- Args:
777
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
778
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
779
- `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
780
- sequence tokens in the vocabulary.
781
-
782
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
783
- `input_ids`.
784
-
785
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
786
- [`PreTrainedTokenizer.__call__`] for details.
787
-
788
- [What are input IDs?](../glossary#input-ids)
789
- past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
790
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
791
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
792
- their past given to this model should not be passed as `input_ids` as they have already been computed.
793
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
794
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
795
-
796
- - 1 for tokens that are **not masked**,
797
- - 0 for tokens that are **masked**.
798
-
799
- If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
800
- `past_key_values`. In other words, the `attention_mask` always has to have the length:
801
- `len(past_key_values) + len(input_ids)`
802
-
803
- [What are attention masks?](../glossary#attention-mask)
804
- token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
805
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
806
- 1]`:
807
-
808
- - 0 corresponds to a *sentence A* token,
809
- - 1 corresponds to a *sentence B* token.
810
-
811
- [What are token type IDs?](../glossary#token-type-ids)
812
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
813
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
814
- config.max_position_embeddings - 1]`.
815
-
816
- [What are position IDs?](../glossary#position-ids)
817
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
818
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
819
-
820
- - 1 indicates the head is **not masked**,
821
- - 0 indicates the head is **masked**.
822
-
823
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
824
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
825
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
826
- model's internal embedding lookup matrix.
827
-
828
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
829
- `past_key_values`).
830
- use_cache (`bool`, *optional*):
831
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
832
- `past_key_values`).
833
- output_attentions (`bool`, *optional*):
834
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
835
- tensors for more detail.
836
- output_hidden_states (`bool`, *optional*):
837
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
838
- more detail.
839
- return_dict (`bool`, *optional*):
840
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
841
- """
842
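A short standalone sketch of the additive-mask convention the docstring describes (1 = attend, 0 = masked): the eager attention path further down expands the 2D mask to `(batch, 1, 1, seq)` and converts it into 0 / dtype-min biases. The tensor values here are made up.

```python
import torch

# 2D padding mask converted to additive biases, as in the eager attention path.
dtype = torch.float32
attention_mask = torch.tensor([[1, 1, 1, 0]])            # last position is padding
bias = attention_mask[:, None, None, :].to(dtype)        # (batch, 1, 1, seq)
bias = (1.0 - bias) * torch.finfo(dtype).min
print(bias)  # 0 where attention is allowed, dtype-min where it is masked
```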
- PARALLELIZE_DOCSTRING = r"""
843
- This is an experimental feature and is subject to change at a moment's notice.
844
-
845
- Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
846
- it will evenly distribute blocks across all devices.
847
-
848
- Args:
849
- device_map (`Dict[int, list]`, *optional*):
850
- A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
851
- automatically mapped to the first device (for esoteric reasons). That means that the first device should
852
- have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
853
- following number of attention modules:
854
-
855
- - openai-community/gpt2: 12
856
- - openai-community/gpt2-medium: 24
857
- - openai-community/gpt2-large: 36
858
- - openai-community/gpt2-xl: 48
859
-
860
- Example:
861
-
862
- ```python
863
- # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
864
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
865
- device_map = {
866
- 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
867
- 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
868
- 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
869
- 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
870
- }
871
- model.parallelize(device_map)
872
- ```
873
- """
874
- DEPARALLELIZE_DOCSTRING = r"""
875
- Moves the model to cpu from a model parallel state.
876
-
877
- Example:
878
-
879
- ```python
880
- # On a 4 GPU machine with openai-community/gpt2-large:
881
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
882
- device_map = {
883
- 0: [0, 1, 2, 3, 4, 5, 6, 7],
884
- 1: [8, 9, 10, 11, 12, 13, 14, 15],
885
- 2: [16, 17, 18, 19, 20, 21, 22, 23],
886
- 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
887
- }
888
- model.parallelize(device_map) # Splits the model across several devices
889
- model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
890
- ```
891
- """
892
-
893
-
894
- @add_start_docstrings(
895
- "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
896
- GPT2_START_DOCSTRING,
897
- )
898
- class GPT2Model(GPT2PreTrainedModel):
899
- _supports_param_buffer_assignment = False
900
-
901
- def __init__(self, config):
902
- super().__init__(config)
903
-
904
- self.embed_dim = config.hidden_size
905
-
906
- self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
907
- self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
908
-
909
- self.drop = nn.Dropout(config.embd_pdrop)
910
- self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
911
- self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
912
-
913
- # Model parallel
914
- self.model_parallel = False
915
- self.device_map = None
916
- self.gradient_checkpointing = False
917
- self._attn_implementation = config._attn_implementation
918
-
919
- # Initialize weights and apply final processing
920
- self.post_init()
921
-
922
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
923
- def parallelize(self, device_map=None):
924
- # Check validity of device_map
925
- warnings.warn(
926
- "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
927
- " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
928
- " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
929
- " ...}",
930
- FutureWarning,
931
- )
932
- self.device_map = (
933
- get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
934
- )
935
- assert_device_map(self.device_map, len(self.h))
936
- self.model_parallel = True
937
- self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
938
- self.last_device = "cuda:" + str(max(self.device_map.keys()))
939
- self.wte = self.wte.to(self.first_device)
940
- self.wpe = self.wpe.to(self.first_device)
941
- # Load onto devices
942
- for k, v in self.device_map.items():
943
- for block in v:
944
- cuda_device = "cuda:" + str(k)
945
- self.h[block] = self.h[block].to(cuda_device)
946
- # ln_f to last
947
- self.ln_f = self.ln_f.to(self.last_device)
948
-
949
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
950
- def deparallelize(self):
951
- warnings.warn(
952
- "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
953
- FutureWarning,
954
- )
955
- self.model_parallel = False
956
- self.device_map = None
957
- self.first_device = "cpu"
958
- self.last_device = "cpu"
959
- self.wte = self.wte.to("cpu")
960
- self.wpe = self.wpe.to("cpu")
961
- for index in range(len(self.h)):
962
- self.h[index] = self.h[index].to("cpu")
963
- self.ln_f = self.ln_f.to("cpu")
964
- torch.cuda.empty_cache()
965
-
966
- def get_input_embeddings(self):
967
- return self.wte
968
-
969
- def set_input_embeddings(self, new_embeddings):
970
- self.wte = new_embeddings
971
-
972
- def _prune_heads(self, heads_to_prune):
973
- """
974
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
975
- """
976
- for layer, heads in heads_to_prune.items():
977
- self.h[layer].attn.prune_heads(heads)
978
-
979
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
980
- @add_code_sample_docstrings(
981
- checkpoint=_CHECKPOINT_FOR_DOC,
982
- output_type=BaseModelOutputWithPastAndCrossAttentions,
983
- config_class=_CONFIG_FOR_DOC,
984
- )
985
- def forward(
986
- self,
987
- input_ids: Optional[torch.LongTensor] = None,
988
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
989
- attention_mask: Optional[torch.FloatTensor] = None,
990
- token_type_ids: Optional[torch.LongTensor] = None,
991
- position_ids: Optional[torch.LongTensor] = None,
992
- head_mask: Optional[torch.FloatTensor] = None,
993
- inputs_embeds: Optional[torch.FloatTensor] = None,
994
- encoder_hidden_states: Optional[torch.Tensor] = None,
995
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
996
- use_cache: Optional[bool] = None,
997
- output_attentions: Optional[bool] = None,
998
- output_hidden_states: Optional[bool] = None,
999
- return_dict: Optional[bool] = None,
1000
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
1001
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1002
- output_hidden_states = (
1003
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1004
- )
1005
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1006
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1007
-
1008
- if input_ids is not None and inputs_embeds is not None:
1009
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1010
- elif input_ids is not None:
1011
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1012
- input_shape = input_ids.size()
1013
- input_ids = input_ids.view(-1, input_shape[-1])
1014
- batch_size = input_ids.shape[0]
1015
- elif inputs_embeds is not None:
1016
- input_shape = inputs_embeds.size()[:-1]
1017
- batch_size = inputs_embeds.shape[0]
1018
- else:
1019
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1020
-
1021
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1022
-
1023
- if token_type_ids is not None:
1024
- token_type_ids = token_type_ids.view(-1, input_shape[-1])
1025
-
1026
- if past_key_values is None:
1027
- past_length = 0
1028
- past_key_values = tuple([None] * len(self.h))
1029
- else:
1030
- past_length = past_key_values[0][0].size(-2)
1031
- if position_ids is None:
1032
- position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
1033
- position_ids = position_ids.unsqueeze(0)
1034
-
1035
- if inputs_embeds is None:
1036
- inputs_embeds = self.wte(input_ids)
1037
- position_embeds = self.wpe(position_ids)
1038
- hidden_states = inputs_embeds + position_embeds
1039
-
1040
- # Attention mask.
1041
- _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
1042
- attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
1043
- if self._attn_implementation == "flash_attention_2":
1044
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1045
- elif _use_sdpa:
1046
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1047
- attention_mask=attention_mask,
1048
- input_shape=(batch_size, input_shape[-1]),
1049
- inputs_embeds=inputs_embeds,
1050
- past_key_values_length=past_length,
1051
- )
1052
- else:
1053
- if attention_mask is not None:
1054
- # We create a 3D attention mask from a 2D tensor mask.
1055
- # Sizes are [batch_size, 1, 1, to_seq_length]
1056
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
1057
- # this attention mask is simpler than the triangular masking of causal attention
1058
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
1059
- attention_mask = attention_mask[:, None, None, :]
1060
-
1061
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1062
- # masked positions, this operation will create a tensor which is 0.0 for
1063
- # positions we want to attend and the dtype's smallest value for masked positions.
1064
- # Since we are adding it to the raw scores before the softmax, this is
1065
- # effectively the same as removing these entirely.
1066
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
1067
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
1068
-
1069
- # If a 2D or 3D attention mask is provided for the cross-attention
1070
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1071
- if self.config.add_cross_attention and encoder_hidden_states is not None:
1072
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1073
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1074
- if encoder_attention_mask is None:
1075
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1076
- if _use_sdpa:
1077
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1078
- mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
1079
- )
1080
- elif not self._attn_implementation == "flash_attention_2":
1081
- encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1082
- else:
1083
- encoder_attention_mask = None
1084
-
1085
- # Prepare head mask if needed
1086
- # 1.0 in head_mask indicate we keep the head
1087
- # attention_probs has shape bsz x n_heads x N x N
1088
- # head_mask has shape n_layer x batch x n_heads x N x N
1089
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
1090
-
1091
- if token_type_ids is not None:
1092
- token_type_embeds = self.wte(token_type_ids)
1093
- hidden_states = hidden_states + token_type_embeds
1094
-
1095
- hidden_states = self.drop(hidden_states)
1096
-
1097
- output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
1098
-
1099
- if self.gradient_checkpointing and self.training:
1100
- if use_cache:
1101
- logger.warning_once(
1102
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1103
- )
1104
- use_cache = False
1105
-
1106
- presents = () if use_cache else None
1107
- all_self_attentions = () if output_attentions else None
1108
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
1109
- all_hidden_states = () if output_hidden_states else None
1110
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
1111
- # Model parallel
1112
- if self.model_parallel:
1113
- torch.cuda.set_device(hidden_states.device)
1114
- # Ensure layer_past is on same device as hidden_states (might not be correct)
1115
- if layer_past is not None:
1116
- layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
1117
- # Ensure that attention_mask is always on the same device as hidden_states
1118
- if attention_mask is not None:
1119
- attention_mask = attention_mask.to(hidden_states.device)
1120
- if isinstance(head_mask, torch.Tensor):
1121
- head_mask = head_mask.to(hidden_states.device)
1122
- if output_hidden_states:
1123
- all_hidden_states = all_hidden_states + (hidden_states,)
1124
-
1125
- if self.gradient_checkpointing and self.training:
1126
- outputs = self._gradient_checkpointing_func(
1127
- block.__call__,
1128
- hidden_states,
1129
- None,
1130
- attention_mask,
1131
- head_mask[i],
1132
- encoder_hidden_states,
1133
- encoder_attention_mask,
1134
- use_cache,
1135
- output_attentions,
1136
- )
1137
- else:
1138
- outputs = block(
1139
- hidden_states,
1140
- layer_past=layer_past,
1141
- attention_mask=attention_mask,
1142
- head_mask=head_mask[i],
1143
- encoder_hidden_states=encoder_hidden_states,
1144
- encoder_attention_mask=encoder_attention_mask,
1145
- use_cache=use_cache,
1146
- output_attentions=output_attentions,
1147
- )
1148
-
1149
- hidden_states = outputs[0]
1150
- if use_cache is True:
1151
- presents = presents + (outputs[1],)
1152
-
1153
- if output_attentions:
1154
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
1155
- if self.config.add_cross_attention:
1156
- all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
1157
-
1158
- # Model Parallel: If it's the last layer for that device, put things on the next device
1159
- if self.model_parallel:
1160
- for k, v in self.device_map.items():
1161
- if i == v[-1] and "cuda:" + str(k) != self.last_device:
1162
- hidden_states = hidden_states.to("cuda:" + str(k + 1))
1163
-
1164
- hidden_states = self.ln_f(hidden_states)
1165
-
1166
- hidden_states = hidden_states.view(output_shape)
1167
- # Add last hidden state
1168
- if output_hidden_states:
1169
- all_hidden_states = all_hidden_states + (hidden_states,)
1170
-
1171
- if not return_dict:
1172
- return tuple(
1173
- v
1174
- for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
1175
- if v is not None
1176
- )
1177
-
1178
- return BaseModelOutputWithPastAndCrossAttentions(
1179
- last_hidden_state=hidden_states,
1180
- past_key_values=presents,
1181
- hidden_states=all_hidden_states,
1182
- attentions=all_self_attentions,
1183
- cross_attentions=all_cross_attentions,
1184
- )
1185
-
1186
-
1187
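A tiny sketch of the cache-aware position ids computed in `GPT2Model.forward` above: with `past_length` positions already cached, the freshly fed tokens receive positions `past_length .. past_length + seq_len - 1`. The numbers below are arbitrary.

```python
import torch

# Position ids for new tokens when a KV cache already covers `past_length` positions.
past_length, new_tokens = 7, 3
position_ids = torch.arange(past_length, past_length + new_tokens, dtype=torch.long).unsqueeze(0)
print(position_ids)  # tensor([[7, 8, 9]])
```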
- @add_start_docstrings(
1188
- """
1189
- The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
1190
- embeddings).
1191
- """,
1192
- GPT2_START_DOCSTRING,
1193
- )
1194
- class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
1195
- _tied_weights_keys = ["lm_head.weight"]
1196
-
1197
- def __init__(self, config):
1198
- super().__init__(config)
1199
- self.transformer = GPT2Model(config)
1200
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1201
-
1202
- # Model parallel
1203
- self.model_parallel = False
1204
- self.device_map = None
1205
-
1206
- # Initialize weights and apply final processing
1207
- self.post_init()
1208
-
1209
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
1210
- def parallelize(self, device_map=None):
1211
- warnings.warn(
1212
- "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1213
- " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1214
- " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
1215
- " 0, 'transformer.h.1': 1, ...}",
1216
- FutureWarning,
1217
- )
1218
- self.device_map = (
1219
- get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1220
- if device_map is None
1221
- else device_map
1222
- )
1223
- assert_device_map(self.device_map, len(self.transformer.h))
1224
- self.transformer.parallelize(self.device_map)
1225
- self.lm_head = self.lm_head.to(self.transformer.first_device)
1226
- self.model_parallel = True
1227
-
1228
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1229
- def deparallelize(self):
1230
- warnings.warn(
1231
- "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1232
- FutureWarning,
1233
- )
1234
- self.transformer.deparallelize()
1235
- self.transformer = self.transformer.to("cpu")
1236
- self.lm_head = self.lm_head.to("cpu")
1237
- self.model_parallel = False
1238
- torch.cuda.empty_cache()
1239
-
1240
- def get_output_embeddings(self):
1241
- return self.lm_head
1242
-
1243
- def set_output_embeddings(self, new_embeddings):
1244
- self.lm_head = new_embeddings
1245
-
1246
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1247
- @add_code_sample_docstrings(
1248
- checkpoint=_CHECKPOINT_FOR_DOC,
1249
- output_type=CausalLMOutputWithCrossAttentions,
1250
- config_class=_CONFIG_FOR_DOC,
1251
- )
1252
- def forward(
1253
- self,
1254
- input_ids: Optional[torch.LongTensor] = None,
1255
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1256
- attention_mask: Optional[torch.FloatTensor] = None,
1257
- token_type_ids: Optional[torch.LongTensor] = None,
1258
- position_ids: Optional[torch.LongTensor] = None,
1259
- head_mask: Optional[torch.FloatTensor] = None,
1260
- inputs_embeds: Optional[torch.FloatTensor] = None,
1261
- encoder_hidden_states: Optional[torch.Tensor] = None,
1262
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1263
- labels: Optional[torch.LongTensor] = None,
1264
- use_cache: Optional[bool] = None,
1265
- output_attentions: Optional[bool] = None,
1266
- output_hidden_states: Optional[bool] = None,
1267
- return_dict: Optional[bool] = None,
1268
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1269
- r"""
1270
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1271
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1272
- `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1273
- are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1274
- """
1275
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1276
-
1277
- transformer_outputs = self.transformer(
1278
- input_ids,
1279
- past_key_values=past_key_values,
1280
- attention_mask=attention_mask,
1281
- token_type_ids=token_type_ids,
1282
- position_ids=position_ids,
1283
- head_mask=head_mask,
1284
- inputs_embeds=inputs_embeds,
1285
- encoder_hidden_states=encoder_hidden_states,
1286
- encoder_attention_mask=encoder_attention_mask,
1287
- use_cache=use_cache,
1288
- output_attentions=output_attentions,
1289
- output_hidden_states=output_hidden_states,
1290
- return_dict=return_dict,
1291
- )
1292
- hidden_states = transformer_outputs[0]
1293
-
1294
- # Set device for model parallelism
1295
- if self.model_parallel:
1296
- torch.cuda.set_device(self.transformer.first_device)
1297
- hidden_states = hidden_states.to(self.lm_head.weight.device)
1298
-
1299
- lm_logits = self.lm_head(hidden_states)
1300
-
1301
- loss = None
1302
- if labels is not None:
1303
- # move labels to correct device to enable model parallelism
1304
- labels = labels.to(lm_logits.device)
1305
- # Shift so that tokens < n predict n
1306
- shift_logits = lm_logits[..., :-1, :].contiguous()
1307
- shift_labels = labels[..., 1:].contiguous()
1308
- # Flatten the tokens
1309
- loss_fct = CrossEntropyLoss()
1310
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1311
-
1312
- if not return_dict:
1313
- output = (lm_logits,) + transformer_outputs[1:]
1314
- return ((loss,) + output) if loss is not None else output
1315
-
1316
- return CausalLMOutputWithCrossAttentions(
1317
- loss=loss,
1318
- logits=lm_logits,
1319
- past_key_values=transformer_outputs.past_key_values,
1320
- hidden_states=transformer_outputs.hidden_states,
1321
- attentions=transformer_outputs.attentions,
1322
- cross_attentions=transformer_outputs.cross_attentions,
1323
- )
1324
-
1325
- @staticmethod
1326
- def _reorder_cache(
1327
- past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1328
- ) -> Tuple[Tuple[torch.Tensor]]:
1329
- """
1330
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1331
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1332
- beam_idx at every generation step.
1333
- """
1334
- return tuple(
1335
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1336
- for layer_past in past_key_values
1337
- )
1338
-
1339
-
1340
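A compact sketch of the label shift performed in `GPT2LMHeadModel.forward` above: logits at position t are scored against the token at position t + 1. The vocabulary size and token ids are made up.

```python
import torch
from torch.nn import CrossEntropyLoss

# Shift logits/labels so that tokens < n predict n, then compute the flattened loss.
vocab_size = 11
lm_logits = torch.randn(1, 5, vocab_size)
labels = torch.tensor([[3, 7, 2, 9, 4]])

shift_logits = lm_logits[..., :-1, :].contiguous()   # predictions for positions 0..3
shift_labels = labels[..., 1:].contiguous()          # targets are tokens 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())
```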
- @add_start_docstrings(
1341
- """
1342
- The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1343
- RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1344
- input embeddings, while the classification head takes as input the hidden state at a specified classification token index in the
1346
- input sequence.
1346
- """,
1347
- GPT2_START_DOCSTRING,
1348
- )
1349
- class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
1350
- _tied_weights_keys = ["lm_head.weight"]
1351
-
1352
- def __init__(self, config):
1353
- super().__init__(config)
1354
- config.num_labels = 1
1355
- self.transformer = GPT2Model(config)
1356
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1357
- self.multiple_choice_head = SequenceSummary(config)
1358
-
1359
- # Model parallel
1360
- self.model_parallel = False
1361
- self.device_map = None
1362
-
1363
- # Initialize weights and apply final processing
1364
- self.post_init()
1365
-
1366
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
1367
- def parallelize(self, device_map=None):
1368
- warnings.warn(
1369
- "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1370
- " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1371
- " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1372
- " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1373
- FutureWarning,
1374
- )
1375
- self.device_map = (
1376
- get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1377
- if device_map is None
1378
- else device_map
1379
- )
1380
- assert_device_map(self.device_map, len(self.transformer.h))
1381
- self.transformer.parallelize(self.device_map)
1382
- self.lm_head = self.lm_head.to(self.transformer.first_device)
1383
- self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
1384
- self.model_parallel = True
1385
-
1386
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1387
- def deparallelize(self):
1388
- warnings.warn(
1389
- "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1390
- FutureWarning,
1391
- )
1392
- self.transformer.deparallelize()
1393
- self.transformer = self.transformer.to("cpu")
1394
- self.lm_head = self.lm_head.to("cpu")
1395
- self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1396
- self.model_parallel = False
1397
- torch.cuda.empty_cache()
1398
-
1399
- def get_output_embeddings(self):
1400
- return self.lm_head
1401
-
1402
- def set_output_embeddings(self, new_embeddings):
1403
- self.lm_head = new_embeddings
1404
-
1405
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1406
- @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
1407
- def forward(
1408
- self,
1409
- input_ids: Optional[torch.LongTensor] = None,
1410
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1411
- attention_mask: Optional[torch.FloatTensor] = None,
1412
- token_type_ids: Optional[torch.LongTensor] = None,
1413
- position_ids: Optional[torch.LongTensor] = None,
1414
- head_mask: Optional[torch.FloatTensor] = None,
1415
- inputs_embeds: Optional[torch.FloatTensor] = None,
1416
- mc_token_ids: Optional[torch.LongTensor] = None,
1417
- labels: Optional[torch.LongTensor] = None,
1418
- mc_labels: Optional[torch.LongTensor] = None,
1419
- use_cache: Optional[bool] = None,
1420
- output_attentions: Optional[bool] = None,
1421
- output_hidden_states: Optional[bool] = None,
1422
- return_dict: Optional[bool] = None,
1423
- **kwargs,
1424
- ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
1425
- r"""
1426
- mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
1427
- Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1428
- 1]`.
1429
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1430
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1431
- `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1432
- `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1433
- mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1434
- Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
1435
- where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
1436
-
1437
- Return:
1438
-
1439
- Example:
1440
-
1441
- ```python
1442
- >>> import torch
1443
- >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
1444
-
1445
- >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
1446
- >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
1447
-
1448
- >>> # Add a [CLS] to the vocabulary (we should train it also!)
1449
- >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1450
- >>> # Update the model embeddings with the new vocabulary size
1451
- >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1452
-
1453
- >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1454
- >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1455
- >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1456
-
1457
- >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1458
- >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1459
-
1460
- >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1461
- >>> lm_logits = outputs.logits
1462
- >>> mc_logits = outputs.mc_logits
1463
- ```"""
1464
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1465
-
1466
- transformer_outputs = self.transformer(
1467
- input_ids,
1468
- past_key_values=past_key_values,
1469
- attention_mask=attention_mask,
1470
- token_type_ids=token_type_ids,
1471
- position_ids=position_ids,
1472
- head_mask=head_mask,
1473
- inputs_embeds=inputs_embeds,
1474
- use_cache=use_cache,
1475
- output_attentions=output_attentions,
1476
- output_hidden_states=output_hidden_states,
1477
- return_dict=return_dict,
1478
- )
1479
-
1480
- hidden_states = transformer_outputs[0]
1481
-
1482
- # Set device for model parallelism
1483
- if self.model_parallel:
1484
- torch.cuda.set_device(self.transformer.first_device)
1485
- hidden_states = hidden_states.to(self.lm_head.weight.device)
1486
-
1487
- lm_logits = self.lm_head(hidden_states)
1488
- mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1489
-
1490
- mc_loss = None
1491
- if mc_labels is not None:
1492
- loss_fct = CrossEntropyLoss()
1493
- mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
1494
- lm_loss = None
1495
- if labels is not None:
1496
- labels = labels.to(lm_logits.device)
1497
- shift_logits = lm_logits[..., :-1, :].contiguous()
1498
- shift_labels = labels[..., 1:].contiguous()
1499
- loss_fct = CrossEntropyLoss()
1500
- lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1501
-
1502
- if not return_dict:
1503
- output = (lm_logits, mc_logits) + transformer_outputs[1:]
1504
- if mc_loss is not None:
1505
- output = (mc_loss,) + output
1506
- return ((lm_loss,) + output) if lm_loss is not None else output
1507
-
1508
- return GPT2DoubleHeadsModelOutput(
1509
- loss=lm_loss,
1510
- mc_loss=mc_loss,
1511
- logits=lm_logits,
1512
- mc_logits=mc_logits,
1513
- past_key_values=transformer_outputs.past_key_values,
1514
- hidden_states=transformer_outputs.hidden_states,
1515
- attentions=transformer_outputs.attentions,
1516
- )
1517
-
1518
- @staticmethod
1519
- def _reorder_cache(
1520
- past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1521
- ) -> Tuple[Tuple[torch.Tensor]]:
1522
- """
1523
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1524
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1525
- beam_idx at every generation step.
1526
- """
1527
- return tuple(
1528
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1529
- for layer_past in past_key_values
1530
- )
1531
-
1532
-
1533
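A standalone sketch of the `_reorder_cache` gather above: every cached key/value tensor is index-selected along the batch (beam) dimension so the cache follows the beams chosen at this generation step. Shapes and the beam selection are hypothetical.

```python
import torch

# Reorder one layer's KV cache to follow the selected beams.
num_beams, heads, seq, head_dim = 3, 2, 4, 8
layer_past = (torch.randn(num_beams, heads, seq, head_dim),
              torch.randn(num_beams, heads, seq, head_dim))
beam_idx = torch.tensor([2, 0, 0])  # hypothetical beam selection

reordered = tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
print(reordered[0].shape)  # torch.Size([3, 2, 4, 8])
```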
- @add_start_docstrings(
1534
- """
1535
- The GPT2 Model transformer with a sequence classification head on top (linear layer).
1536
-
1537
- [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1538
- (e.g. GPT-1) do.
1539
-
1540
- Since it does classification on the last token, it requires to know the position of the last token. If a
1541
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1542
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1543
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1544
- each row of the batch).
1545
- """,
1546
- GPT2_START_DOCSTRING,
1547
- )
1548
- class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1549
- def __init__(self, config):
1550
- super().__init__(config)
1551
- self.num_labels = config.num_labels
1552
- self.transformer = GPT2Model(config)
1553
- self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1554
-
1555
- # Model parallel
1556
- self.model_parallel = False
1557
- self.device_map = None
1558
-
1559
- # Initialize weights and apply final processing
1560
- self.post_init()
1561
-
1562
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1563
- @add_code_sample_docstrings(
1564
- checkpoint="microsoft/DialogRPT-updown",
1565
- output_type=SequenceClassifierOutputWithPast,
1566
- config_class=_CONFIG_FOR_DOC,
1567
- )
1568
- def forward(
1569
- self,
1570
- input_ids: Optional[torch.LongTensor] = None,
1571
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1572
- attention_mask: Optional[torch.FloatTensor] = None,
1573
- token_type_ids: Optional[torch.LongTensor] = None,
1574
- position_ids: Optional[torch.LongTensor] = None,
1575
- head_mask: Optional[torch.FloatTensor] = None,
1576
- inputs_embeds: Optional[torch.FloatTensor] = None,
1577
- labels: Optional[torch.LongTensor] = None,
1578
- use_cache: Optional[bool] = None,
1579
- output_attentions: Optional[bool] = None,
1580
- output_hidden_states: Optional[bool] = None,
1581
- return_dict: Optional[bool] = None,
1582
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1583
- r"""
1584
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1585
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1586
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1587
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1588
- """
1589
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1590
-
1591
- transformer_outputs = self.transformer(
1592
- input_ids,
1593
- past_key_values=past_key_values,
1594
- attention_mask=attention_mask,
1595
- token_type_ids=token_type_ids,
1596
- position_ids=position_ids,
1597
- head_mask=head_mask,
1598
- inputs_embeds=inputs_embeds,
1599
- use_cache=use_cache,
1600
- output_attentions=output_attentions,
1601
- output_hidden_states=output_hidden_states,
1602
- return_dict=return_dict,
1603
- )
1604
- hidden_states = transformer_outputs[0]
1605
- logits = self.score(hidden_states)
1606
-
1607
- if input_ids is not None:
1608
- batch_size, sequence_length = input_ids.shape[:2]
1609
- else:
1610
- batch_size, sequence_length = inputs_embeds.shape[:2]
1611
-
1612
- assert (
1613
- self.config.pad_token_id is not None or batch_size == 1
1614
- ), "Cannot handle batch sizes > 1 if no padding token is defined."
1615
- if self.config.pad_token_id is None:
1616
- sequence_lengths = -1
1617
- else:
1618
- if input_ids is not None:
1619
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1620
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1621
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1622
- sequence_lengths = sequence_lengths.to(logits.device)
1623
- else:
1624
- sequence_lengths = -1
1625
- logger.warning_once(
1626
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1627
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1628
- )
1629
-
1630
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1631
-
1632
- loss = None
1633
- if labels is not None:
1634
- if self.config.problem_type is None:
1635
- if self.num_labels == 1:
1636
- self.config.problem_type = "regression"
1637
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1638
- self.config.problem_type = "single_label_classification"
1639
- else:
1640
- self.config.problem_type = "multi_label_classification"
1641
-
1642
- if self.config.problem_type == "regression":
1643
- loss_fct = MSELoss()
1644
- if self.num_labels == 1:
1645
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1646
- else:
1647
- loss = loss_fct(pooled_logits, labels)
1648
- elif self.config.problem_type == "single_label_classification":
1649
- loss_fct = CrossEntropyLoss()
1650
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1651
- elif self.config.problem_type == "multi_label_classification":
1652
- loss_fct = BCEWithLogitsLoss()
1653
- loss = loss_fct(pooled_logits, labels)
1654
- if not return_dict:
1655
- output = (pooled_logits,) + transformer_outputs[1:]
1656
- return ((loss,) + output) if loss is not None else output
1657
-
1658
- return SequenceClassifierOutputWithPast(
1659
- loss=loss,
1660
- logits=pooled_logits,
1661
- past_key_values=transformer_outputs.past_key_values,
1662
- hidden_states=transformer_outputs.hidden_states,
1663
- attentions=transformer_outputs.attentions,
1664
- )
1665
-
1666
-
1667
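A small sketch of the last-non-padding-token lookup used above: `argmax` finds the first pad token, subtracting one gives the preceding position, and the modulo maps the -1 produced by fully unpadded rows onto the last position, keeping the expression ONNX-friendly. The ids below are made up.

```python
import torch

# Locate the last non-padding token per row without reverse indexing.
pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 2, 3, 4]])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])
```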
- @add_start_docstrings(
1668
- """
1669
- GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1670
- Named-Entity-Recognition (NER) tasks.
1671
- """,
1672
- GPT2_START_DOCSTRING,
1673
- )
1674
- class GPT2ForTokenClassification(GPT2PreTrainedModel):
1675
- def __init__(self, config):
1676
- super().__init__(config)
1677
- self.num_labels = config.num_labels
1678
-
1679
- self.transformer = GPT2Model(config)
1680
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1681
- classifier_dropout = config.classifier_dropout
1682
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1683
- classifier_dropout = config.hidden_dropout
1684
- else:
1685
- classifier_dropout = 0.1
1686
- self.dropout = nn.Dropout(classifier_dropout)
1687
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1688
-
1689
- # Model parallel
1690
- self.model_parallel = False
1691
- self.device_map = None
1692
-
1693
- # Initialize weights and apply final processing
1694
- self.post_init()
1695
-
1696
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1697
- # fmt: off
1698
- @add_code_sample_docstrings(
1699
- checkpoint="brad1141/gpt2-finetuned-comp2",
1700
- output_type=TokenClassifierOutput,
1701
- config_class=_CONFIG_FOR_DOC,
1702
- expected_loss=0.25,
1703
- expected_output=[
1704
- "Lead",
1705
- "Lead",
1706
- "Lead",
1707
- "Position",
1708
- "Lead",
1709
- "Lead",
1710
- "Lead",
1711
- "Lead",
1712
- "Lead",
1713
- "Lead",
1714
- "Lead",
1715
- "Lead",
1716
- ],
1717
- )
1718
- # fmt: on
1719
- def forward(
1720
- self,
1721
- input_ids: Optional[torch.LongTensor] = None,
1722
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1723
- attention_mask: Optional[torch.FloatTensor] = None,
1724
- token_type_ids: Optional[torch.LongTensor] = None,
1725
- position_ids: Optional[torch.LongTensor] = None,
1726
- head_mask: Optional[torch.FloatTensor] = None,
1727
- inputs_embeds: Optional[torch.FloatTensor] = None,
1728
- labels: Optional[torch.LongTensor] = None,
1729
- use_cache: Optional[bool] = None,
1730
- output_attentions: Optional[bool] = None,
1731
- output_hidden_states: Optional[bool] = None,
1732
- return_dict: Optional[bool] = None,
1733
- ) -> Union[Tuple, TokenClassifierOutput]:
1734
- r"""
1735
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1736
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1737
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1738
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1739
- """
1740
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1741
-
1742
- transformer_outputs = self.transformer(
1743
- input_ids,
1744
- past_key_values=past_key_values,
1745
- attention_mask=attention_mask,
1746
- token_type_ids=token_type_ids,
1747
- position_ids=position_ids,
1748
- head_mask=head_mask,
1749
- inputs_embeds=inputs_embeds,
1750
- use_cache=use_cache,
1751
- output_attentions=output_attentions,
1752
- output_hidden_states=output_hidden_states,
1753
- return_dict=return_dict,
1754
- )
1755
-
1756
- hidden_states = transformer_outputs[0]
1757
- hidden_states = self.dropout(hidden_states)
1758
- logits = self.classifier(hidden_states)
1759
-
1760
- loss = None
1761
- if labels is not None:
1762
- labels = labels.to(logits.device)
1763
- loss_fct = CrossEntropyLoss()
1764
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1765
-
1766
- if not return_dict:
1767
- output = (logits,) + transformer_outputs[2:]
1768
- return ((loss,) + output) if loss is not None else output
1769
-
1770
- return TokenClassifierOutput(
1771
- loss=loss,
1772
- logits=logits,
1773
- hidden_states=transformer_outputs.hidden_states,
1774
- attentions=transformer_outputs.attentions,
1775
- )
1776
-
1777
-
1778
- @add_start_docstrings(
1779
- """
1780
- The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like
1781
- SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1782
- """,
1783
- GPT2_START_DOCSTRING,
1784
- )
1785
- class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
1786
- def __init__(self, config):
1787
- super().__init__(config)
1788
- self.num_labels = config.num_labels
1789
- self.transformer = GPT2Model(config)
1790
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1791
-
1792
- # Model parallel
1793
- self.model_parallel = False
1794
- self.device_map = None
1795
-
1796
- # Initialize weights and apply final processing
1797
- self.post_init()
1798
-
1799
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1800
- @add_code_sample_docstrings(
1801
- checkpoint=_CHECKPOINT_FOR_DOC,
1802
- output_type=QuestionAnsweringModelOutput,
1803
- config_class=_CONFIG_FOR_DOC,
1804
- real_checkpoint=_CHECKPOINT_FOR_DOC,
1805
- )
1806
- def forward(
1807
- self,
1808
- input_ids: Optional[torch.LongTensor] = None,
1809
- attention_mask: Optional[torch.FloatTensor] = None,
1810
- token_type_ids: Optional[torch.LongTensor] = None,
1811
- position_ids: Optional[torch.LongTensor] = None,
1812
- head_mask: Optional[torch.FloatTensor] = None,
1813
- inputs_embeds: Optional[torch.FloatTensor] = None,
1814
- start_positions: Optional[torch.LongTensor] = None,
1815
- end_positions: Optional[torch.LongTensor] = None,
1816
- output_attentions: Optional[bool] = None,
1817
- output_hidden_states: Optional[bool] = None,
1818
- return_dict: Optional[bool] = None,
1819
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1820
- r"""
1821
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1822
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1823
- Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1824
- are not taken into account for computing the loss.
1825
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1826
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1827
- Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1828
- are not taken into account for computing the loss.
1829
- """
1830
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1831
-
1832
- outputs = self.transformer(
1833
- input_ids,
1834
- attention_mask=attention_mask,
1835
- token_type_ids=token_type_ids,
1836
- position_ids=position_ids,
1837
- head_mask=head_mask,
1838
- inputs_embeds=inputs_embeds,
1839
- output_attentions=output_attentions,
1840
- output_hidden_states=output_hidden_states,
1841
- return_dict=return_dict,
1842
- )
1843
-
1844
- sequence_output = outputs[0]
1845
-
1846
- logits = self.qa_outputs(sequence_output)
1847
- start_logits, end_logits = logits.split(1, dim=-1)
1848
- start_logits = start_logits.squeeze(-1).contiguous()
1849
- end_logits = end_logits.squeeze(-1).contiguous()
1850
-
1851
- total_loss = None
1852
- if start_positions is not None and end_positions is not None:
1853
- # If we are on multi-GPU, the split may have added an extra dimension; remove it
1854
- if len(start_positions.size()) > 1:
1855
- start_positions = start_positions.squeeze(-1).to(start_logits.device)
1856
- if len(end_positions.size()) > 1:
1857
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
1858
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1859
- ignored_index = start_logits.size(1)
1860
- start_positions = start_positions.clamp(0, ignored_index)
1861
- end_positions = end_positions.clamp(0, ignored_index)
1862
-
1863
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1864
- start_loss = loss_fct(start_logits, start_positions)
1865
- end_loss = loss_fct(end_logits, end_positions)
1866
- total_loss = (start_loss + end_loss) / 2
1867
-
1868
- if not return_dict:
1869
- output = (start_logits, end_logits) + outputs[2:]
1870
- return ((total_loss,) + output) if total_loss is not None else output
1871
-
1872
- return QuestionAnsweringModelOutput(
1873
- loss=total_loss,
1874
- start_logits=start_logits,
1875
- end_logits=end_logits,
1876
- hidden_states=outputs.hidden_states,
1877
- attentions=outputs.attentions,
1878
- )
 
indextts/gpt/transformers_modeling_utils.py DELETED
The diff for this file is too large to render. See raw diff
 
indextts/infer.py DELETED
@@ -1,690 +0,0 @@
1
- import os
2
-
3
- os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
4
- import time
5
- from subprocess import CalledProcessError
6
- from typing import Dict, List
7
-
8
- import torch
9
- import torchaudio
10
- from torch.nn.utils.rnn import pad_sequence
11
- from omegaconf import OmegaConf
12
- from tqdm import tqdm
13
-
14
- import warnings
15
-
16
- warnings.filterwarnings("ignore", category=FutureWarning)
17
- warnings.filterwarnings("ignore", category=UserWarning)
18
-
19
- from indextts.BigVGAN.models import BigVGAN as Generator
20
- from indextts.gpt.model import UnifiedVoice
21
- from indextts.utils.checkpoint import load_checkpoint
22
- from indextts.utils.feature_extractors import MelSpectrogramFeatures
23
-
24
- from indextts.utils.front import TextNormalizer, TextTokenizer
25
-
26
-
27
- class IndexTTS:
28
- def __init__(
29
- self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=True, device=None,
30
- use_cuda_kernel=None,
31
- ):
32
- """
33
- Args:
34
- cfg_path (str): path to the config file.
35
- model_dir (str): path to the model directory.
36
- use_fp16 (bool): whether to use fp16.
37
- device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
38
- use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
39
- """
40
- if device is not None:
41
- self.device = device
42
- self.use_fp16 = False if device == "cpu" else use_fp16
43
- self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
44
- elif torch.cuda.is_available():
45
- self.device = "cuda:0"
46
- self.use_fp16 = use_fp16
47
- self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
48
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
49
- self.device = "xpu"
50
- self.use_fp16 = use_fp16
51
- self.use_cuda_kernel = False
52
- elif hasattr(torch, "mps") and torch.backends.mps.is_available():
53
- self.device = "mps"
54
- self.use_fp16 = False # Using float16 on MPS has more overhead than float32
55
- self.use_cuda_kernel = False
56
- else:
57
- self.device = "cpu"
58
- self.use_fp16 = False
59
- self.use_cuda_kernel = False
60
- print(">> Be patient, it may take a while to run in CPU mode.")
61
-
62
- self.cfg = OmegaConf.load(cfg_path)
63
- self.model_dir = model_dir
64
- self.dtype = torch.float16 if self.use_fp16 else None
65
- self.stop_mel_token = self.cfg.gpt.stop_mel_token
66
-
67
- # Comment-off to load the VQ-VAE model for debugging tokenizer
68
- # https://github.com/index-tts/index-tts/issues/34
69
- #
70
- # from indextts.vqvae.xtts_dvae import DiscreteVAE
71
- # self.dvae = DiscreteVAE(**self.cfg.vqvae)
72
- # self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
73
- # load_checkpoint(self.dvae, self.dvae_path)
74
- # self.dvae = self.dvae.to(self.device)
75
- # if self.use_fp16:
76
- # self.dvae.eval().half()
77
- # else:
78
- # self.dvae.eval()
79
- # print(">> vqvae weights restored from:", self.dvae_path)
80
- self.gpt = UnifiedVoice(**self.cfg.gpt)
81
- self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
82
- load_checkpoint(self.gpt, self.gpt_path)
83
- self.gpt = self.gpt.to(self.device)
84
- if self.use_fp16:
85
- self.gpt.eval().half()
86
- else:
87
- self.gpt.eval()
88
- print(">> GPT weights restored from:", self.gpt_path)
89
- if self.use_fp16:
90
- try:
91
- import deepspeed
92
-
93
- use_deepspeed = True
94
- except (ImportError, OSError, CalledProcessError) as e:
95
- use_deepspeed = False
96
- print(f">> Failed to load DeepSpeed, falling back to standard inference: {e}")
97
-
98
- self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
99
- else:
100
- self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
101
-
102
- if self.use_cuda_kernel:
103
- # preload the CUDA kernel for BigVGAN
104
- try:
105
- from indextts.BigVGAN.alias_free_activation.cuda import load
106
-
107
- anti_alias_activation_cuda = load.load()
108
- print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
109
- except:
110
- print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
111
- self.use_cuda_kernel = False
112
- self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel)
113
- self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
114
- vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu")
115
- self.bigvgan.load_state_dict(vocoder_dict["generator"])
116
- self.bigvgan = self.bigvgan.to(self.device)
117
- # remove weight norm on eval mode
118
- self.bigvgan.remove_weight_norm()
119
- self.bigvgan.eval()
120
- print(">> bigvgan weights restored from:", self.bigvgan_path)
121
- self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
122
- self.normalizer = TextNormalizer()
123
- self.normalizer.load()
124
- print(">> TextNormalizer loaded")
125
- self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
126
- print(">> bpe model loaded from:", self.bpe_path)
127
- # Cache of the reference audio mel:
128
- self.cache_audio_prompt = None
129
- self.cache_cond_mel = None
130
- # Progress reporting hook (optional)
131
- self.gr_progress = None
132
- self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None
133
-
134
- def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30):
135
- """
136
- Shrink special tokens (silent_token and stop_mel_token) in codes
137
- codes: [B, T]
138
- """
139
- code_lens = []
140
- codes_list = []
141
- device = codes.device
142
- dtype = codes.dtype
143
- isfix = False
144
- for i in range(0, codes.shape[0]):
145
- code = codes[i]
146
- if not torch.any(code == self.stop_mel_token).item():
147
- len_ = code.size(0)
148
- else:
149
- stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False)
150
- len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0)
151
-
152
- count = torch.sum(code == silent_token).item()
153
- if count > max_consecutive:
154
- # code = code.cpu().tolist()
155
- ncode_idx = []
156
- n = 0
157
- for k in range(len_):
158
- assert code[
159
- k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here"
160
- if code[k] != silent_token:
161
- ncode_idx.append(k)
162
- n = 0
163
- elif code[k] == silent_token and n < 10:
164
- ncode_idx.append(k)
165
- n += 1
166
- # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
167
- # n += 1
168
- # new code
169
- len_ = len(ncode_idx)
170
- codes_list.append(code[ncode_idx])
171
- isfix = True
172
- else:
173
- # shrink to len_
174
- codes_list.append(code[:len_])
175
- code_lens.append(len_)
176
- if isfix:
177
- if len(codes_list) > 1:
178
- codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token)
179
- else:
180
- codes = codes_list[0].unsqueeze(0)
181
- else:
182
- # unchanged
183
- pass
184
- # clip codes to max length
185
- max_len = max(code_lens)
186
- if max_len < codes.shape[1]:
187
- codes = codes[:, :max_len]
188
- code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
189
- return codes, code_lens
190
-
191
- def bucket_segments(self, segments, bucket_max_size=4) -> List[List[Dict]]:
192
- """
193
- Segment data bucketing.
194
- if ``bucket_max_size=1``, return all segments in one bucket.
195
- """
196
- outputs: List[Dict] = []
197
- for idx, sent in enumerate(segments):
198
- outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
199
-
200
- if len(outputs) > bucket_max_size:
201
- # split segments into buckets by segment length
202
- buckets: List[List[Dict]] = []
203
- factor = 1.5
204
- last_bucket = None
205
- last_bucket_sent_len_median = 0
206
-
207
- for sent in sorted(outputs, key=lambda x: x["len"]):
208
- current_sent_len = sent["len"]
209
- if current_sent_len == 0:
210
- print(">> skip empty segment")
211
- continue
212
- if last_bucket is None \
213
- or current_sent_len >= int(last_bucket_sent_len_median * factor) \
214
- or len(last_bucket) >= bucket_max_size:
215
- # new bucket
216
- buckets.append([sent])
217
- last_bucket = buckets[-1]
218
- last_bucket_sent_len_median = current_sent_len
219
- else:
220
- # current bucket can hold more segments
221
- last_bucket.append(sent) # sorted
222
- mid = len(last_bucket) // 2
223
- last_bucket_sent_len_median = last_bucket[mid]["len"]
224
- last_bucket = None
225
- # merge all buckets with size 1
226
- out_buckets: List[List[Dict]] = []
227
- only_ones: List[Dict] = []
228
- for b in buckets:
229
- if len(b) == 1:
230
- only_ones.append(b[0])
231
- else:
232
- out_buckets.append(b)
233
- if len(only_ones) > 0:
234
- # merge into previous buckets if possible
235
- # print("only_ones:", [(o["idx"], o["len"]) for o in only_ones])
236
- for i in range(len(out_buckets)):
237
- b = out_buckets[i]
238
- if len(b) < bucket_max_size:
239
- b.append(only_ones.pop(0))
240
- if len(only_ones) == 0:
241
- break
242
- # combined all remaining sized 1 buckets
243
- if len(only_ones) > 0:
244
- out_buckets.extend(
245
- [only_ones[i:i + bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)])
246
- return out_buckets
247
- return [outputs]
248
-
249
- def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
250
- if self.model_version and self.model_version >= 1.5:
251
- # For model version 1.5 and above, right-pad directly with stop_text_token up to the max length
252
- # [1, N] -> [N,]
253
- tokens = [t.squeeze(0) for t in tokens]
254
- return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token,
255
- padding_side="right")
256
- max_len = max(t.size(1) for t in tokens)
257
- outputs = []
258
- for tensor in tokens:
259
- pad_len = max_len - tensor.size(1)
260
- if pad_len > 0:
261
- n = min(8, pad_len)
262
- tensor = torch.nn.functional.pad(tensor, (0, n), value=self.cfg.gpt.stop_text_token)
263
- tensor = torch.nn.functional.pad(tensor, (0, pad_len - n), value=self.cfg.gpt.start_text_token)
264
- tensor = tensor[:, :max_len]
265
- outputs.append(tensor)
266
- tokens = torch.cat(outputs, dim=0)
267
- return tokens
268
-
269
- def torch_empty_cache(self):
270
- try:
271
- if "cuda" in str(self.device):
272
- torch.cuda.empty_cache()
273
- elif "mps" in str(self.device):
274
- torch.mps.empty_cache()
275
- except Exception as e:
276
- pass
277
-
278
- def _set_gr_progress(self, value, desc):
279
- if self.gr_progress is not None:
280
- self.gr_progress(value, desc=desc)
281
-
282
- # Fast inference: for long multi-sentence text, achieves at least a 2~10x speedup (First modified by sunnyboxs 2025-04-16)
283
- def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=100,
284
- segments_bucket_max_size=4, **generation_kwargs):
285
- """
286
- Args:
287
- ``max_text_tokens_per_segment``: maximum number of text tokens per segment, default ``100``; adjust according to your GPU hardware
288
- - smaller values: more batches, *faster* inference, higher memory usage, may affect quality
289
- - larger values: fewer batches, *slower* inference, memory usage and quality closer to non-fast inference
290
- ``segments_bucket_max_size``: maximum capacity of each segment bucket, default ``4``; adjust according to your GPU memory
291
- - larger values: fewer buckets, more batches, *faster* inference, higher memory usage, may affect quality
292
- - smaller values: more buckets, fewer batches, *slower* inference, memory usage and quality closer to non-fast inference
293
- """
294
- print(">> starting fast inference...")
295
-
296
- self._set_gr_progress(0, "starting fast inference...")
297
- if verbose:
298
- print(f"origin text:{text}")
299
- start_time = time.perf_counter()
300
-
301
- # Only regenerate cond_mel when the reference audio has changed, to speed things up
302
- if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
303
- audio, sr = torchaudio.load(audio_prompt)
304
- audio = torch.mean(audio, dim=0, keepdim=True)
305
- if audio.shape[0] > 1:
306
- audio = audio[0].unsqueeze(0)
307
- audio = torchaudio.transforms.Resample(sr, 24000)(audio)
308
-
309
- max_audio_length_seconds = 50
310
- max_audio_samples = int(max_audio_length_seconds * 24000)
311
-
312
- if audio.shape[1] > max_audio_samples:
313
- if verbose:
314
- print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
315
- audio = audio[:, :max_audio_samples]
316
-
317
- cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
318
- cond_mel_frame = cond_mel.shape[-1]
319
- if verbose:
320
- print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
321
-
322
- self.cache_audio_prompt = audio_prompt
323
- self.cache_cond_mel = cond_mel
324
- else:
325
- cond_mel = self.cache_cond_mel
326
- cond_mel_frame = cond_mel.shape[-1]
327
- pass
328
-
329
- auto_conditioning = cond_mel
330
- cond_mel_lengths = torch.tensor([cond_mel_frame], device=self.device)
331
-
332
- # text_tokens
333
- text_tokens_list = self.tokenizer.tokenize(text)
334
-
335
- segments = self.tokenizer.split_segments(text_tokens_list,
336
- max_text_tokens_per_segment=max_text_tokens_per_segment)
337
- if verbose:
338
- print(">> text token count:", len(text_tokens_list))
339
- print(" segments count:", len(segments))
340
- print(" max_text_tokens_per_segment:", max_text_tokens_per_segment)
341
- print(*segments, sep="\n")
342
- do_sample = generation_kwargs.pop("do_sample", True)
343
- top_p = generation_kwargs.pop("top_p", 0.8)
344
- top_k = generation_kwargs.pop("top_k", 30)
345
- temperature = generation_kwargs.pop("temperature", 1.0)
346
- autoregressive_batch_size = 1
347
- length_penalty = generation_kwargs.pop("length_penalty", 0.0)
348
- num_beams = generation_kwargs.pop("num_beams", 3)
349
- repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
350
- max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
351
- sampling_rate = 24000
352
- # lang = "EN"
353
- # lang = "ZH"
354
- wavs = []
355
- gpt_gen_time = 0
356
- gpt_forward_time = 0
357
- bigvgan_time = 0
358
-
359
- # text processing
360
- all_text_tokens: List[List[torch.Tensor]] = []
361
- self._set_gr_progress(0.1, "text processing...")
362
- bucket_max_size = segments_bucket_max_size if self.device != "cpu" else 1
363
- all_segments = self.bucket_segments(segments, bucket_max_size=bucket_max_size)
364
- bucket_count = len(all_segments)
365
- if verbose:
366
- print(">> segments bucket_count:", bucket_count,
367
- "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_segments],
368
- "bucket_max_size:", bucket_max_size)
369
- for segments in all_segments:
370
- temp_tokens: List[torch.Tensor] = []
371
- all_text_tokens.append(temp_tokens)
372
- for item in segments:
373
- sent = item["sent"]
374
- text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
375
- text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
376
- if verbose:
377
- print(text_tokens)
378
- print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
379
- # debug tokenizer
380
- text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
381
- print("text_token_syms is same as segment tokens", text_token_syms == sent)
382
- temp_tokens.append(text_tokens)
383
-
384
- # Sequential processing of bucketing data
385
- all_batch_num = sum(len(s) for s in all_segments)
386
- all_batch_codes = []
387
- processed_num = 0
388
- for item_tokens in all_text_tokens:
389
- batch_num = len(item_tokens)
390
- if batch_num > 1:
391
- batch_text_tokens = self.pad_tokens_cat(item_tokens)
392
- else:
393
- batch_text_tokens = item_tokens[0]
394
- processed_num += batch_num
395
- # gpt speech
396
- self._set_gr_progress(0.2 + 0.3 * processed_num / all_batch_num,
397
- f"gpt speech inference {processed_num}/{all_batch_num}...")
398
- m_start_time = time.perf_counter()
399
- with torch.no_grad():
400
- with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None,
401
- dtype=self.dtype):
402
- temp_codes = self.gpt.inference_speech(auto_conditioning, batch_text_tokens,
403
- cond_mel_lengths=cond_mel_lengths,
404
- # text_lengths=text_len,
405
- do_sample=do_sample,
406
- top_p=top_p,
407
- top_k=top_k,
408
- temperature=temperature,
409
- num_return_sequences=autoregressive_batch_size,
410
- length_penalty=length_penalty,
411
- num_beams=num_beams,
412
- repetition_penalty=repetition_penalty,
413
- max_generate_length=max_mel_tokens,
414
- **generation_kwargs)
415
- all_batch_codes.append(temp_codes)
416
- gpt_gen_time += time.perf_counter() - m_start_time
417
-
418
- # gpt latent
419
- self._set_gr_progress(0.5, "gpt latents inference...")
420
- all_idxs = []
421
- all_latents = []
422
- has_warned = False
423
- for batch_codes, batch_tokens, batch_segments in zip(all_batch_codes, all_text_tokens, all_segments):
424
- for i in range(batch_codes.shape[0]):
425
- codes = batch_codes[i] # [x]
426
- if not has_warned and codes[-1] != self.stop_mel_token:
427
- warnings.warn(
428
- f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
429
- f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
430
- category=RuntimeWarning
431
- )
432
- has_warned = True
433
- codes = codes.unsqueeze(0) # [x] -> [1, x]
434
- if verbose:
435
- print("codes:", codes.shape)
436
- print(codes)
437
- codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
438
- if verbose:
439
- print("fix codes:", codes.shape)
440
- print(codes)
441
- print("code_lens:", code_lens)
442
- text_tokens = batch_tokens[i]
443
- all_idxs.append(batch_segments[i]["idx"])
444
- m_start_time = time.perf_counter()
445
- with torch.no_grad():
446
- with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
447
- latent = \
448
- self.gpt(auto_conditioning, text_tokens,
449
- torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
450
- code_lens * self.gpt.mel_length_compression,
451
- cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
452
- device=text_tokens.device),
453
- return_latent=True, clip_inputs=False)
454
- gpt_forward_time += time.perf_counter() - m_start_time
455
- all_latents.append(latent)
456
- del all_batch_codes, all_text_tokens, all_segments
457
- # bigvgan chunk
458
- chunk_size = 2
459
- all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
460
- if verbose:
461
- print(">> all_latents:", len(all_latents))
462
- print(" latents length:", [l.shape[1] for l in all_latents])
463
- chunk_latents = [all_latents[i: i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
464
- chunk_length = len(chunk_latents)
465
- latent_length = len(all_latents)
466
-
467
- # bigvgan chunk decode
468
- self._set_gr_progress(0.7, "bigvgan decoding...")
469
- tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
470
- for items in chunk_latents:
471
- tqdm_progress.update(len(items))
472
- latent = torch.cat(items, dim=1)
473
- with torch.no_grad():
474
- with torch.amp.autocast(latent.device.type, enabled=self.dtype is not None, dtype=self.dtype):
475
- m_start_time = time.perf_counter()
476
- wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
477
- bigvgan_time += time.perf_counter() - m_start_time
478
- wav = wav.squeeze(1)
479
- pass
480
- wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
481
- wavs.append(wav.cpu()) # to cpu before saving
482
-
483
- # clear cache
484
- tqdm_progress.close() # make sure the progress bar is closed
485
- del all_latents, chunk_latents
486
- end_time = time.perf_counter()
487
- self.torch_empty_cache()
488
-
489
- # wav audio output
490
- self._set_gr_progress(0.9, "saving audio...")
491
- wav = torch.cat(wavs, dim=1)
492
- wav_length = wav.shape[-1] / sampling_rate
493
- print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
494
- print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
495
- print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
496
- print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
497
- print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
498
- print(f">> Generated audio length: {wav_length:.2f} seconds")
499
- print(f">> [fast] bigvgan chunk_length: {chunk_length}")
500
- print(f">> [fast] batch_num: {all_batch_num} bucket_max_size: {bucket_max_size}",
501
- f"bucket_count: {bucket_count}" if bucket_max_size > 1 else "")
502
- print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")
503
-
504
- # save audio
505
- wav = wav.cpu() # to cpu
506
- if output_path:
507
- # save the audio directly to the specified path
508
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
509
- torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
510
- print(">> wav file saved to:", output_path)
511
- return output_path
512
- else:
513
- # return in the format expected by Gradio
514
- wav_data = wav.type(torch.int16)
515
- wav_data = wav_data.numpy().T
516
- return (sampling_rate, wav_data)
517
-
518
- # Original inference mode
519
- def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=120,
520
- **generation_kwargs):
521
- print(">> starting inference...")
522
- self._set_gr_progress(0, "starting inference...")
523
- if verbose:
524
- print(f"origin text:{text}")
525
- start_time = time.perf_counter()
526
-
527
- # Only regenerate cond_mel when the reference audio has changed, to speed things up
528
- if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
529
- audio, sr = torchaudio.load(audio_prompt)
530
- audio = torch.mean(audio, dim=0, keepdim=True)
531
- if audio.shape[0] > 1:
532
- audio = audio[0].unsqueeze(0)
533
- audio = torchaudio.transforms.Resample(sr, 24000)(audio)
534
- cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
535
- cond_mel_frame = cond_mel.shape[-1]
536
- if verbose:
537
- print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
538
-
539
- self.cache_audio_prompt = audio_prompt
540
- self.cache_cond_mel = cond_mel
541
- else:
542
- cond_mel = self.cache_cond_mel
543
- cond_mel_frame = cond_mel.shape[-1]
544
- pass
545
-
546
- self._set_gr_progress(0.1, "text processing...")
547
- auto_conditioning = cond_mel
548
- text_tokens_list = self.tokenizer.tokenize(text)
549
- segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
550
- if verbose:
551
- print("text token count:", len(text_tokens_list))
552
- print("segments count:", len(segments))
553
- print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
554
- print(*segments, sep="\n")
555
- do_sample = generation_kwargs.pop("do_sample", True)
556
- top_p = generation_kwargs.pop("top_p", 0.8)
557
- top_k = generation_kwargs.pop("top_k", 30)
558
- temperature = generation_kwargs.pop("temperature", 1.0)
559
- autoregressive_batch_size = 1
560
- length_penalty = generation_kwargs.pop("length_penalty", 0.0)
561
- num_beams = generation_kwargs.pop("num_beams", 3)
562
- repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
563
- max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
564
- sampling_rate = 24000
565
- # lang = "EN"
566
- # lang = "ZH"
567
- wavs = []
568
- gpt_gen_time = 0
569
- gpt_forward_time = 0
570
- bigvgan_time = 0
571
- progress = 0
572
- has_warned = False
573
- for sent in segments:
574
- text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
575
- text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
576
- # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
577
- # text_tokens = F.pad(text_tokens, (1, 0), value=0)
578
- # text_tokens = F.pad(text_tokens, (0, 1), value=1)
579
- if verbose:
580
- print(text_tokens)
581
- print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
582
- # debug tokenizer
583
- text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
584
- print("text_token_syms is same as segment tokens", text_token_syms == sent)
585
-
586
- # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
587
- # print(text_len)
588
- progress += 1
589
- self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(segments),
590
- f"gpt latents inference {progress}/{len(segments)}...")
591
- m_start_time = time.perf_counter()
592
- with torch.no_grad():
593
- with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
594
- codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
595
- cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
596
- device=text_tokens.device),
597
- # text_lengths=text_len,
598
- do_sample=do_sample,
599
- top_p=top_p,
600
- top_k=top_k,
601
- temperature=temperature,
602
- num_return_sequences=autoregressive_batch_size,
603
- length_penalty=length_penalty,
604
- num_beams=num_beams,
605
- repetition_penalty=repetition_penalty,
606
- max_generate_length=max_mel_tokens,
607
- **generation_kwargs)
608
- gpt_gen_time += time.perf_counter() - m_start_time
609
- if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
610
- warnings.warn(
611
- f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
612
- f"Input text tokens: {text_tokens.shape[1]}. "
613
- f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
614
- category=RuntimeWarning
615
- )
616
- has_warned = True
617
-
618
- code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
619
- if verbose:
620
- print(codes, type(codes))
621
- print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
622
- print(f"code len: {code_lens}")
623
-
624
- # remove ultra-long silence if it exists
625
- # temporarily fix the long silence bug.
626
- codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
627
- if verbose:
628
- print(codes, type(codes))
629
- print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
630
- print(f"code len: {code_lens}")
631
- self._set_gr_progress(0.2 + 0.4 * progress / len(segments),
632
- f"gpt speech inference {progress}/{len(segments)}...")
633
- m_start_time = time.perf_counter()
634
- # latent, text_lens_out, code_lens_out = \
635
- with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
636
- latent = \
637
- self.gpt(auto_conditioning, text_tokens,
638
- torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
639
- code_lens * self.gpt.mel_length_compression,
640
- cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
641
- device=text_tokens.device),
642
- return_latent=True, clip_inputs=False)
643
- gpt_forward_time += time.perf_counter() - m_start_time
644
-
645
- m_start_time = time.perf_counter()
646
- wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
647
- bigvgan_time += time.perf_counter() - m_start_time
648
- wav = wav.squeeze(1)
649
-
650
- wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
651
- if verbose:
652
- print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
653
- # wavs.append(wav[:, :-512])
654
- wavs.append(wav.cpu()) # to cpu before saving
655
- end_time = time.perf_counter()
656
- self._set_gr_progress(0.9, "saving audio...")
657
- wav = torch.cat(wavs, dim=1)
658
- wav_length = wav.shape[-1] / sampling_rate
659
- print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
660
- print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
661
- print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
662
- print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
663
- print(f">> Total inference time: {end_time - start_time:.2f} seconds")
664
- print(f">> Generated audio length: {wav_length:.2f} seconds")
665
- print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
666
-
667
- # save audio
668
- wav = wav.cpu() # to cpu
669
- if output_path:
670
- # save the audio directly to the specified path
671
- if os.path.isfile(output_path):
672
- os.remove(output_path)
673
- print(">> remove old wav file:", output_path)
674
- if os.path.dirname(output_path) != "":
675
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
676
- torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
677
- print(">> wav file saved to:", output_path)
678
- return output_path
679
- else:
680
- # return in the format expected by Gradio
681
- wav_data = wav.type(torch.int16)
682
- wav_data = wav_data.numpy().T
683
- return (sampling_rate, wav_data)
684
-
685
- if __name__ == "__main__":
686
- prompt_wav = "examples/voice_01.wav"
687
- text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。'
688
-
689
- tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
690
- tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
 
indextts/infer_v2.py DELETED
@@ -1,739 +0,0 @@
1
- import os
2
- from subprocess import CalledProcessError
3
-
4
- os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
5
- import json
6
- import re
7
- import time
8
- import librosa
9
- import torch
10
- import torchaudio
11
- from torch.nn.utils.rnn import pad_sequence
12
-
13
- import warnings
14
-
15
- warnings.filterwarnings("ignore", category=FutureWarning)
16
- warnings.filterwarnings("ignore", category=UserWarning)
17
-
18
- from omegaconf import OmegaConf
19
-
20
- from indextts.gpt.model_v2 import UnifiedVoice
21
- from indextts.utils.maskgct_utils import build_semantic_model, build_semantic_codec
22
- from indextts.utils.checkpoint import load_checkpoint
23
- from indextts.utils.front import TextNormalizer, TextTokenizer
24
-
25
- from indextts.s2mel.modules.commons import load_checkpoint2, MyModel
26
- from indextts.s2mel.modules.bigvgan import bigvgan
27
- from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus
28
- from indextts.s2mel.modules.audio import mel_spectrogram
29
-
30
- from transformers import AutoTokenizer
31
- from modelscope import AutoModelForCausalLM
32
- from huggingface_hub import hf_hub_download
33
- import safetensors
34
- from transformers import SeamlessM4TFeatureExtractor
35
- import random
36
- import torch.nn.functional as F
37
-
38
- class IndexTTS2:
39
- def __init__(
40
- self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
41
- use_cuda_kernel=None,use_deepspeed=False
42
- ):
43
- """
44
- Args:
45
- cfg_path (str): path to the config file.
46
- model_dir (str): path to the model directory.
47
- use_fp16 (bool): whether to use fp16.
48
- device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
49
- use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
50
- use_deepspeed (bool): whether to use DeepSpeed or not.
51
- """
52
- if device is not None:
53
- self.device = device
54
- self.use_fp16 = False if device == "cpu" else use_fp16
55
- self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
56
- elif torch.cuda.is_available():
57
- self.device = "cuda:0"
58
- self.use_fp16 = use_fp16
59
- self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
60
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
61
- self.device = "xpu"
62
- self.use_fp16 = use_fp16
63
- self.use_cuda_kernel = False
64
- elif hasattr(torch, "mps") and torch.backends.mps.is_available():
65
- self.device = "mps"
66
- self.use_fp16 = False # Using float16 on MPS has more overhead than float32
67
- self.use_cuda_kernel = False
68
- else:
69
- self.device = "cpu"
70
- self.use_fp16 = False
71
- self.use_cuda_kernel = False
72
- print(">> Be patient, it may take a while to run in CPU mode.")
73
-
74
- self.cfg = OmegaConf.load(cfg_path)
75
- self.model_dir = model_dir
76
- self.dtype = torch.float16 if self.use_fp16 else None
77
- self.stop_mel_token = self.cfg.gpt.stop_mel_token
78
-
79
- self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
80
-
81
- self.gpt = UnifiedVoice(**self.cfg.gpt)
82
- self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
83
- load_checkpoint(self.gpt, self.gpt_path)
84
- self.gpt = self.gpt.to(self.device)
85
- if self.use_fp16:
86
- self.gpt.eval().half()
87
- else:
88
- self.gpt.eval()
89
- print(">> GPT weights restored from:", self.gpt_path)
90
-
91
- if use_deepspeed:
92
- try:
93
- import deepspeed
94
- except (ImportError, OSError, CalledProcessError) as e:
95
- use_deepspeed = False
96
- print(f">> Failed to load DeepSpeed. Falling back to normal inference. Error: {e}")
97
-
98
- self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=self.use_fp16)
99
-
100
- if self.use_cuda_kernel:
101
- # preload the CUDA kernel for BigVGAN
102
- try:
103
- from indextts.s2mel.modules.bigvgan.alias_free_activation.cuda import activation1d
104
-
105
- print(">> Preload custom CUDA kernel for BigVGAN", activation1d.anti_alias_activation_cuda)
106
- except Exception as e:
107
- print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
108
- print(f"{e!r}")
109
- self.use_cuda_kernel = False
110
-
111
- self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
112
- self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
113
- os.path.join(self.model_dir, self.cfg.w2v_stat))
114
- self.semantic_model = self.semantic_model.to(self.device)
115
- self.semantic_model.eval()
116
- self.semantic_mean = self.semantic_mean.to(self.device)
117
- self.semantic_std = self.semantic_std.to(self.device)
118
-
119
- semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
120
- semantic_code_ckpt = hf_hub_download("amphion/MaskGCT", filename="semantic_codec/model.safetensors")
121
- safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
122
- self.semantic_codec = semantic_codec.to(self.device)
123
- self.semantic_codec.eval()
124
- print('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt))
125
-
126
- s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint)
127
- s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True)
128
- s2mel, _, _, _ = load_checkpoint2(
129
- s2mel,
130
- None,
131
- s2mel_path,
132
- load_only_params=True,
133
- ignore_modules=[],
134
- is_distributed=False,
135
- )
136
- self.s2mel = s2mel.to(self.device)
137
- self.s2mel.models['cfm'].estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
138
- self.s2mel.eval()
139
- print(">> s2mel weights restored from:", s2mel_path)
140
-
141
- # load campplus_model
142
- campplus_ckpt_path = hf_hub_download(
143
- "funasr/campplus", filename="campplus_cn_common.bin"
144
- )
145
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
146
- campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
147
- self.campplus_model = campplus_model.to(self.device)
148
- self.campplus_model.eval()
149
- print(">> campplus_model weights restored from:", campplus_ckpt_path)
150
-
151
- bigvgan_name = self.cfg.vocoder.name
152
- self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=self.use_cuda_kernel)
153
- self.bigvgan = self.bigvgan.to(self.device)
154
- self.bigvgan.remove_weight_norm()
155
- self.bigvgan.eval()
156
- print(">> bigvgan weights restored from:", bigvgan_name)
157
-
158
- self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
159
- self.normalizer = TextNormalizer()
160
- self.normalizer.load()
161
- print(">> TextNormalizer loaded")
162
- self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
163
- print(">> bpe model loaded from:", self.bpe_path)
164
-
165
- emo_matrix = torch.load(os.path.join(self.model_dir, self.cfg.emo_matrix))
166
- self.emo_matrix = emo_matrix.to(self.device)
167
- self.emo_num = list(self.cfg.emo_num)
168
-
169
- spk_matrix = torch.load(os.path.join(self.model_dir, self.cfg.spk_matrix))
170
- self.spk_matrix = spk_matrix.to(self.device)
171
-
172
- self.emo_matrix = torch.split(self.emo_matrix, self.emo_num)
173
- self.spk_matrix = torch.split(self.spk_matrix, self.emo_num)
174
-
175
- mel_fn_args = {
176
- "n_fft": self.cfg.s2mel['preprocess_params']['spect_params']['n_fft'],
177
- "win_size": self.cfg.s2mel['preprocess_params']['spect_params']['win_length'],
178
- "hop_size": self.cfg.s2mel['preprocess_params']['spect_params']['hop_length'],
179
- "num_mels": self.cfg.s2mel['preprocess_params']['spect_params']['n_mels'],
180
- "sampling_rate": self.cfg.s2mel["preprocess_params"]["sr"],
181
- "fmin": self.cfg.s2mel['preprocess_params']['spect_params'].get('fmin', 0),
182
- "fmax": None if self.cfg.s2mel['preprocess_params']['spect_params'].get('fmax', "None") == "None" else 8000,
183
- "center": False
184
- }
185
- self.mel_fn = lambda x: mel_spectrogram(x, **mel_fn_args)
186
-
187
- # Cache of the reference audio:
188
- self.cache_spk_cond = None
189
- self.cache_s2mel_style = None
190
- self.cache_s2mel_prompt = None
191
- self.cache_spk_audio_prompt = None
192
- self.cache_emo_cond = None
193
- self.cache_emo_audio_prompt = None
194
- self.cache_mel = None
195
-
196
- # Progress reporting hook (optional)
197
- self.gr_progress = None
198
- self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None
199
-
200
- @torch.no_grad()
201
- def get_emb(self, input_features, attention_mask):
202
- vq_emb = self.semantic_model(
203
- input_features=input_features,
204
- attention_mask=attention_mask,
205
- output_hidden_states=True,
206
- )
207
- feat = vq_emb.hidden_states[17] # (B, T, C)
208
- feat = (feat - self.semantic_mean) / self.semantic_std
209
- return feat
210
-
211
- def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30):
212
- """
213
- Shrink special tokens (silent_token and stop_mel_token) in codes
214
- codes: [B, T]
215
- """
216
- code_lens = []
217
- codes_list = []
218
- device = codes.device
219
- dtype = codes.dtype
220
- isfix = False
221
- for i in range(0, codes.shape[0]):
222
- code = codes[i]
223
- if not torch.any(code == self.stop_mel_token).item():
224
- len_ = code.size(0)
225
- else:
226
- stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False)
227
- len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0)
228
-
229
- count = torch.sum(code == silent_token).item()
230
- if count > max_consecutive:
231
- # code = code.cpu().tolist()
232
- ncode_idx = []
233
- n = 0
234
- for k in range(len_):
235
- assert code[
236
- k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here"
237
- if code[k] != silent_token:
238
- ncode_idx.append(k)
239
- n = 0
240
- elif code[k] == silent_token and n < 10:
241
- ncode_idx.append(k)
242
- n += 1
243
- # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
244
- # n += 1
245
- # new code
246
- len_ = len(ncode_idx)
247
- codes_list.append(code[ncode_idx])
248
- isfix = True
249
- else:
250
- # shrink to len_
251
- codes_list.append(code[:len_])
252
- code_lens.append(len_)
253
- if isfix:
254
- if len(codes_list) > 1:
255
- codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token)
256
- else:
257
- codes = codes_list[0].unsqueeze(0)
258
- else:
259
- # unchanged
260
- pass
261
- # clip codes to max length
262
- max_len = max(code_lens)
263
- if max_len < codes.shape[1]:
264
- codes = codes[:, :max_len]
265
- code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
266
- return codes, code_lens
267
-
268
- def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
269
- """
270
- Insert silences between generated segments.
271
- wavs: List[torch.tensor]
272
- """
273
-
274
- if not wavs or interval_silence <= 0:
275
- return wavs
276
-
277
- # get channel_size
278
- channel_size = wavs[0].size(0)
279
- # get silence tensor
280
- sil_dur = int(sampling_rate * interval_silence / 1000.0)
281
- sil_tensor = torch.zeros(channel_size, sil_dur)
282
-
283
- wavs_list = []
284
- for i, wav in enumerate(wavs):
285
- wavs_list.append(wav)
286
- if i < len(wavs) - 1:
287
- wavs_list.append(sil_tensor)
288
-
289
- return wavs_list
290
-
291
- def _set_gr_progress(self, value, desc):
292
- if self.gr_progress is not None:
293
- self.gr_progress(value, desc=desc)
294
-
295
- def _load_and_cut_audio(self,audio_path,max_audio_length_seconds,verbose=False,sr=None):
296
- if not sr:
297
- audio, sr = librosa.load(audio_path)
298
- else:
299
- audio, _ = librosa.load(audio_path,sr=sr)
300
- audio = torch.tensor(audio).unsqueeze(0)
301
- max_audio_samples = int(max_audio_length_seconds * sr)
302
-
303
- if audio.shape[1] > max_audio_samples:
304
- if verbose:
305
- print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
306
- audio = audio[:, :max_audio_samples]
307
- return audio, sr
308
-
309
- # Original inference mode
310
- def infer(self, spk_audio_prompt, text, output_path,
311
- emo_audio_prompt=None, emo_alpha=1.0,
312
- emo_vector=None,
313
- use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
314
- verbose=False, max_text_tokens_per_segment=120, **generation_kwargs):
315
- print(">> starting inference...")
316
- self._set_gr_progress(0, "starting inference...")
317
- if verbose:
318
- print(f"origin text:{text}, spk_audio_prompt:{spk_audio_prompt}, "
319
- f"emo_audio_prompt:{emo_audio_prompt}, emo_alpha:{emo_alpha}, "
320
- f"emo_vector:{emo_vector}, use_emo_text:{use_emo_text}, "
321
- f"emo_text:{emo_text}")
322
- start_time = time.perf_counter()
323
-
324
- if use_emo_text or emo_vector is not None:
325
- # we're using a text or emotion vector guidance; so we must remove
326
- # "emotion reference voice", to ensure we use correct emotion mixing!
327
- emo_audio_prompt = None
328
-
329
- if use_emo_text:
330
- # automatically generate emotion vectors from text prompt
331
- if emo_text is None:
332
- emo_text = text # use main text prompt
333
- emo_dict = self.qwen_emo.inference(emo_text)
334
- print(f"detected emotion vectors from text: {emo_dict}")
335
- # convert ordered dict to list of vectors; the order is VERY important!
336
- emo_vector = list(emo_dict.values())
337
-
338
- if emo_vector is not None:
339
- # we have emotion vectors; they can't be blended via alpha mixing
340
- # in the main inference process later, so we must pre-calculate
341
- # their new strengths here based on the alpha instead!
342
- emo_vector_scale = max(0.0, min(1.0, emo_alpha))
343
- if emo_vector_scale != 1.0:
344
- # scale each vector and truncate to 4 decimals (for nicer printing)
345
- emo_vector = [int(x * emo_vector_scale * 10000) / 10000 for x in emo_vector]
346
- print(f"scaled emotion vectors to {emo_vector_scale}x: {emo_vector}")
347
-
348
- if emo_audio_prompt is None:
349
- # we are not using any external "emotion reference voice"; use
350
- # speaker's voice as the main emotion reference audio.
351
- emo_audio_prompt = spk_audio_prompt
352
- # must always use alpha=1.0 when we don't have an external reference voice
353
- emo_alpha = 1.0
354
-
355
- # Only regenerate when the reference audio has changed, to speed things up
356
- if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt:
357
- audio,sr = self._load_and_cut_audio(spk_audio_prompt,15,verbose)
358
- audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
359
- audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
360
-
361
- inputs = self.extract_features(audio_16k, sampling_rate=16000, return_tensors="pt")
362
- input_features = inputs["input_features"]
363
- attention_mask = inputs["attention_mask"]
364
- input_features = input_features.to(self.device)
365
- attention_mask = attention_mask.to(self.device)
366
- spk_cond_emb = self.get_emb(input_features, attention_mask)
367
-
368
- _, S_ref = self.semantic_codec.quantize(spk_cond_emb)
369
- ref_mel = self.mel_fn(audio_22k.to(spk_cond_emb.device).float())
370
- ref_target_lengths = torch.LongTensor([ref_mel.size(2)]).to(ref_mel.device)
371
- feat = torchaudio.compliance.kaldi.fbank(audio_16k.to(ref_mel.device),
372
- num_mel_bins=80,
373
- dither=0,
374
- sample_frequency=16000)
375
-             feat = feat - feat.mean(dim=0, keepdim=True)  # feat2: a second filterbank-energy feature, shape [922, 80]
376
-             style = self.campplus_model(feat.unsqueeze(0))  # global style embedding of the reference audio, shape [1, 192]
377
-
378
- prompt_condition = self.s2mel.models['length_regulator'](S_ref,
379
- ylens=ref_target_lengths,
380
- n_quantizers=3,
381
- f0=None)[0]
382
-
383
- self.cache_spk_cond = spk_cond_emb
384
- self.cache_s2mel_style = style
385
- self.cache_s2mel_prompt = prompt_condition
386
- self.cache_spk_audio_prompt = spk_audio_prompt
387
- self.cache_mel = ref_mel
388
- else:
389
- style = self.cache_s2mel_style
390
- prompt_condition = self.cache_s2mel_prompt
391
- spk_cond_emb = self.cache_spk_cond
392
- ref_mel = self.cache_mel
393
-
394
- if emo_vector is not None:
395
- weight_vector = torch.tensor(emo_vector).to(self.device)
396
- if use_random:
397
- random_index = [random.randint(0, x - 1) for x in self.emo_num]
398
- else:
399
- random_index = [find_most_similar_cosine(style, tmp) for tmp in self.spk_matrix]
400
-
401
- emo_matrix = [tmp[index].unsqueeze(0) for index, tmp in zip(random_index, self.emo_matrix)]
402
- emo_matrix = torch.cat(emo_matrix, 0)
403
- emovec_mat = weight_vector.unsqueeze(1) * emo_matrix
404
- emovec_mat = torch.sum(emovec_mat, 0)
405
- emovec_mat = emovec_mat.unsqueeze(0)
406
-
407
- if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
408
- emo_audio, _ = self._load_and_cut_audio(emo_audio_prompt,15,verbose,sr=16000)
409
- emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
410
- emo_input_features = emo_inputs["input_features"]
411
- emo_attention_mask = emo_inputs["attention_mask"]
412
- emo_input_features = emo_input_features.to(self.device)
413
- emo_attention_mask = emo_attention_mask.to(self.device)
414
- emo_cond_emb = self.get_emb(emo_input_features, emo_attention_mask)
415
-
416
- self.cache_emo_cond = emo_cond_emb
417
- self.cache_emo_audio_prompt = emo_audio_prompt
418
- else:
419
- emo_cond_emb = self.cache_emo_cond
420
-
421
- self._set_gr_progress(0.1, "text processing...")
422
- text_tokens_list = self.tokenizer.tokenize(text)
423
- segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
424
- segments_count = len(segments)
425
- if verbose:
426
- print("text_tokens_list:", text_tokens_list)
427
- print("segments count:", segments_count)
428
- print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
429
- print(*segments, sep="\n")
430
- do_sample = generation_kwargs.pop("do_sample", True)
431
- top_p = generation_kwargs.pop("top_p", 0.8)
432
- top_k = generation_kwargs.pop("top_k", 30)
433
- temperature = generation_kwargs.pop("temperature", 0.8)
434
- autoregressive_batch_size = 1
435
- length_penalty = generation_kwargs.pop("length_penalty", 0.0)
436
- num_beams = generation_kwargs.pop("num_beams", 3)
437
- repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
438
- max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 1500)
439
- sampling_rate = 22050
440
-
441
- wavs = []
442
- gpt_gen_time = 0
443
- gpt_forward_time = 0
444
- s2mel_time = 0
445
- bigvgan_time = 0
446
- has_warned = False
447
- for seg_idx, sent in enumerate(segments):
448
- self._set_gr_progress(0.2 + 0.7 * seg_idx / segments_count,
449
- f"speech synthesis {seg_idx + 1}/{segments_count}...")
450
-
451
- text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
452
- text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
453
- if verbose:
454
- print(text_tokens)
455
- print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
456
- # debug tokenizer
457
- text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
458
- print("text_token_syms is same as segment tokens", text_token_syms == sent)
459
-
460
- m_start_time = time.perf_counter()
461
- with torch.no_grad():
462
- with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
463
- emovec = self.gpt.merge_emovec(
464
- spk_cond_emb,
465
- emo_cond_emb,
466
- torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
467
- torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
468
- alpha=emo_alpha
469
- )
470
-
471
- if emo_vector is not None:
472
- emovec = emovec_mat + (1 - torch.sum(weight_vector)) * emovec
473
- # emovec = emovec_mat
474
-
475
- codes, speech_conditioning_latent = self.gpt.inference_speech(
476
- spk_cond_emb,
477
- text_tokens,
478
- emo_cond_emb,
479
- cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
480
- emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
481
- emo_vec=emovec,
482
- do_sample=True,
483
- top_p=top_p,
484
- top_k=top_k,
485
- temperature=temperature,
486
- num_return_sequences=autoregressive_batch_size,
487
- length_penalty=length_penalty,
488
- num_beams=num_beams,
489
- repetition_penalty=repetition_penalty,
490
- max_generate_length=max_mel_tokens,
491
- **generation_kwargs
492
- )
493
-
494
- gpt_gen_time += time.perf_counter() - m_start_time
495
- if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
496
- warnings.warn(
497
- f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
498
- f"Input text tokens: {text_tokens.shape[1]}. "
499
- f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
500
- category=RuntimeWarning
501
- )
502
- has_warned = True
503
-
504
- code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
505
- # if verbose:
506
- # print(codes, type(codes))
507
- # print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
508
- # print(f"code len: {code_lens}")
509
-
510
- code_lens = []
511
- for code in codes:
512
- if self.stop_mel_token not in code:
513
- code_lens.append(len(code))
514
- code_len = len(code)
515
- else:
516
- len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
517
- code_len = len_ - 1
518
- code_lens.append(code_len)
519
- codes = codes[:, :code_len]
520
- code_lens = torch.LongTensor(code_lens)
521
- code_lens = code_lens.to(self.device)
522
- if verbose:
523
- print(codes, type(codes))
524
- print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
525
- print(f"code len: {code_lens}")
526
-
527
- m_start_time = time.perf_counter()
528
- use_speed = torch.zeros(spk_cond_emb.size(0)).to(spk_cond_emb.device).long()
529
- with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
530
- latent = self.gpt(
531
- speech_conditioning_latent,
532
- text_tokens,
533
- torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
534
- codes,
535
- torch.tensor([codes.shape[-1]], device=text_tokens.device),
536
- emo_cond_emb,
537
- cond_mel_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
538
- emo_cond_mel_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
539
- emo_vec=emovec,
540
- use_speed=use_speed,
541
- )
542
- gpt_forward_time += time.perf_counter() - m_start_time
543
-
544
- dtype = None
545
- with torch.amp.autocast(text_tokens.device.type, enabled=dtype is not None, dtype=dtype):
546
- m_start_time = time.perf_counter()
547
- diffusion_steps = 25
548
- inference_cfg_rate = 0.7
549
- latent = self.s2mel.models['gpt_layer'](latent)
550
- S_infer = self.semantic_codec.quantizer.vq2emb(codes.unsqueeze(1))
551
- S_infer = S_infer.transpose(1, 2)
552
- S_infer = S_infer + latent
553
- target_lengths = (code_lens * 1.72).long()
554
-
555
- cond = self.s2mel.models['length_regulator'](S_infer,
556
- ylens=target_lengths,
557
- n_quantizers=3,
558
- f0=None)[0]
559
- cat_condition = torch.cat([prompt_condition, cond], dim=1)
560
- vc_target = self.s2mel.models['cfm'].inference(cat_condition,
561
- torch.LongTensor([cat_condition.size(1)]).to(
562
- cond.device),
563
- ref_mel, style, None, diffusion_steps,
564
- inference_cfg_rate=inference_cfg_rate)
565
- vc_target = vc_target[:, :, ref_mel.size(-1):]
566
- s2mel_time += time.perf_counter() - m_start_time
567
-
568
- m_start_time = time.perf_counter()
569
- wav = self.bigvgan(vc_target.float()).squeeze().unsqueeze(0)
570
- print(wav.shape)
571
- bigvgan_time += time.perf_counter() - m_start_time
572
- wav = wav.squeeze(1)
573
-
574
- wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
575
- if verbose:
576
- print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
577
- # wavs.append(wav[:, :-512])
578
- wavs.append(wav.cpu()) # to cpu before saving
579
- end_time = time.perf_counter()
580
-
581
- self._set_gr_progress(0.9, "saving audio...")
582
- wavs = self.insert_interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence)
583
- wav = torch.cat(wavs, dim=1)
584
- wav_length = wav.shape[-1] / sampling_rate
585
- print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
586
- print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
587
- print(f">> s2mel_time: {s2mel_time:.2f} seconds")
588
- print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
589
- print(f">> Total inference time: {end_time - start_time:.2f} seconds")
590
- print(f">> Generated audio length: {wav_length:.2f} seconds")
591
- print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
592
-
593
- # save audio
594
- wav = wav.cpu() # to cpu
595
- if output_path:
596
-             # save the audio directly to the specified output path
597
- if os.path.isfile(output_path):
598
- os.remove(output_path)
599
- print(">> remove old wav file:", output_path)
600
- if os.path.dirname(output_path) != "":
601
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
602
- torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
603
- print(">> wav file saved to:", output_path)
604
- return output_path
605
- else:
606
-             # return in the format expected by Gradio
607
- wav_data = wav.type(torch.int16)
608
- wav_data = wav_data.numpy().T
609
- return (sampling_rate, wav_data)
610
-
611
-
612
- def find_most_similar_cosine(query_vector, matrix):
613
- query_vector = query_vector.float()
614
- matrix = matrix.float()
615
-
616
- similarities = F.cosine_similarity(query_vector, matrix, dim=1)
617
- most_similar_index = torch.argmax(similarities)
618
- return most_similar_index
619
-
620
- class QwenEmotion:
621
- def __init__(self, model_dir):
622
- self.model_dir = model_dir
623
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
624
- self.model = AutoModelForCausalLM.from_pretrained(
625
- self.model_dir,
626
- torch_dtype="float16", # "auto"
627
- device_map="auto"
628
- )
629
-         self.prompt = "文本情感分类"  # system prompt: "text emotion classification"
630
- self.cn_key_to_en = {
631
- "高兴": "happy",
632
- "愤怒": "angry",
633
- "悲伤": "sad",
634
- "恐惧": "afraid",
635
- "反感": "disgusted",
636
- # TODO: the "低落" (melancholic) emotion will always be mapped to
637
- # "悲伤" (sad) by QwenEmotion's text analysis. it doesn't know the
638
- # difference between those emotions even if user writes exact words.
639
- # SEE: `self.melancholic_words` for current workaround.
640
- "低落": "melancholic",
641
- "惊讶": "surprised",
642
- "自然": "calm",
643
- }
644
- self.desired_vector_order = ["高兴", "愤怒", "悲伤", "恐惧", "反感", "低落", "惊讶", "自然"]
645
- self.melancholic_words = {
646
- # emotion text phrases that will force QwenEmotion's "悲伤" (sad) detection
647
- # to become "低落" (melancholic) instead, to fix limitations mentioned above.
648
- "低落",
649
- "melancholy",
650
- "melancholic",
651
- "depression",
652
- "depressed",
653
- "gloomy",
654
- }
655
- self.max_score = 1.2
656
- self.min_score = 0.0
657
-
658
- def clamp_score(self, value):
659
- return max(self.min_score, min(self.max_score, value))
660
-
661
- def convert(self, content):
662
- # generate emotion vector dictionary:
663
- # - insert values in desired order (Python 3.7+ `dict` remembers insertion order)
664
- # - convert Chinese keys to English
665
- # - clamp all values to the allowed min/max range
666
- # - use 0.0 for any values that were missing in `content`
667
- emotion_dict = {
668
- self.cn_key_to_en[cn_key]: self.clamp_score(content.get(cn_key, 0.0))
669
- for cn_key in self.desired_vector_order
670
- }
671
-
672
- # default to a calm/neutral voice if all emotion vectors were empty
673
- if all(val <= 0.0 for val in emotion_dict.values()):
674
- print(">> no emotions detected; using default calm/neutral voice")
675
- emotion_dict["calm"] = 1.0
676
-
677
- return emotion_dict
678
-
679
- def inference(self, text_input):
680
- start = time.time()
681
- messages = [
682
- {"role": "system", "content": f"{self.prompt}"},
683
- {"role": "user", "content": f"{text_input}"}
684
- ]
685
- text = self.tokenizer.apply_chat_template(
686
- messages,
687
- tokenize=False,
688
- add_generation_prompt=True,
689
- enable_thinking=False,
690
- )
691
- model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
692
-
693
- # conduct text completion
694
- generated_ids = self.model.generate(
695
- **model_inputs,
696
- max_new_tokens=32768,
697
- pad_token_id=self.tokenizer.eos_token_id
698
- )
699
- output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
700
-
701
- # parsing thinking content
702
- try:
703
- # rindex finding 151668 (</think>)
704
- index = len(output_ids) - output_ids[::-1].index(151668)
705
- except ValueError:
706
- index = 0
707
-
708
- content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True)
709
-
710
- # decode the JSON emotion detections as a dictionary
711
- try:
712
- content = json.loads(content)
713
- except json.decoder.JSONDecodeError:
714
- # invalid JSON; fallback to manual string parsing
715
- # print(">> parsing QwenEmotion response", content)
716
- content = {
717
- m.group(1): float(m.group(2))
718
- for m in re.finditer(r'([^\s":.,]+?)"?\s*:\s*([\d.]+)', content)
719
- }
720
- # print(">> dict result", content)
721
-
722
- # workaround for QwenEmotion's inability to distinguish "悲伤" (sad) vs "低落" (melancholic).
723
- # if we detect any of the IndexTTS "melancholic" words, we swap those vectors
724
- # to encode the "sad" emotion as "melancholic" (instead of sadness).
725
- text_input_lower = text_input.lower()
726
- if any(word in text_input_lower for word in self.melancholic_words):
727
- # print(">> before vec swap", content)
728
- content["悲伤"], content["低落"] = content.get("低落", 0.0), content.get("悲伤", 0.0)
729
- # print(">> after vec swap", content)
730
-
731
- return self.convert(content)
732
-
733
-
734
- if __name__ == "__main__":
735
- prompt_wav = "examples/voice_01.wav"
736
- text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。'
737
-
738
- tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
739
- tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
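For context on the API removed above, here is a minimal, hypothetical usage sketch of IndexTTS2.infer's emotion controls (the import path and all file paths/text are illustrative placeholders; the eight-element emo_vector follows the order happy, angry, sad, afraid, disgusted, melancholic, surprised, calm used by QwenEmotion):

# Hypothetical sketch of the removed emotion-control API; paths and text are placeholders.
from indextts.infer_v2 import IndexTTS2  # module removed by this commit (import path assumed)

tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)

# 1) Blend an explicit 8-dim emotion vector; emo_alpha scales its overall strength.
tts.infer(spk_audio_prompt="examples/voice_01.wav",
          text="An example sentence.",
          output_path="gen_angry.wav",
          emo_vector=[0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0],  # mostly angry, slightly surprised
          emo_alpha=0.9)

# 2) Let QwenEmotion derive the vector from a free-text emotion description instead.
tts.infer(spk_audio_prompt="examples/voice_01.wav",
          text="An example sentence.",
          output_path="gen_melancholic.wav",
          use_emo_text=True,
          emo_text="a gloomy, melancholic farewell")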
indextts/s2mel/dac/__init__.py DELETED
@@ -1,16 +0,0 @@
1
- __version__ = "1.0.0"
2
-
3
- # preserved here for legacy reasons
4
- __model_version__ = "latest"
5
-
6
- import audiotools
7
-
8
- audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
- audiotools.ml.BaseModel.EXTERN += ["einops"]
10
-
11
-
12
- from . import nn
13
- from . import model
14
- from . import utils
15
- from .model import DAC
16
- from .model import DACFile
indextts/s2mel/dac/__main__.py DELETED
@@ -1,36 +0,0 @@
1
- import sys
2
-
3
- import argbind
4
-
5
- from dac.utils import download
6
- from dac.utils.decode import decode
7
- from dac.utils.encode import encode
8
-
9
- STAGES = ["encode", "decode", "download"]
10
-
11
-
12
- def run(stage: str):
13
- """Run stages.
14
-
15
- Parameters
16
- ----------
17
- stage : str
18
- Stage to run
19
- """
20
- if stage not in STAGES:
21
- raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
- stage_fn = globals()[stage]
23
-
24
- if stage == "download":
25
- stage_fn()
26
- return
27
-
28
- stage_fn()
29
-
30
-
31
- if __name__ == "__main__":
32
- group = sys.argv.pop(1)
33
- args = argbind.parse_args(group=group)
34
-
35
- with argbind.scope(args):
36
- run(group)
indextts/s2mel/dac/model/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- from .base import CodecMixin
2
- from .base import DACFile
3
- from .dac import DAC
4
- from .discriminator import Discriminator
indextts/s2mel/dac/model/base.py DELETED
@@ -1,294 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
- from pathlib import Path
4
- from typing import Union
5
-
6
- import numpy as np
7
- import torch
8
- import tqdm
9
- from audiotools import AudioSignal
10
- from torch import nn
11
-
12
- SUPPORTED_VERSIONS = ["1.0.0"]
13
-
14
-
15
- @dataclass
16
- class DACFile:
17
- codes: torch.Tensor
18
-
19
- # Metadata
20
- chunk_length: int
21
- original_length: int
22
- input_db: float
23
- channels: int
24
- sample_rate: int
25
- padding: bool
26
- dac_version: str
27
-
28
- def save(self, path):
29
- artifacts = {
30
- "codes": self.codes.numpy().astype(np.uint16),
31
- "metadata": {
32
- "input_db": self.input_db.numpy().astype(np.float32),
33
- "original_length": self.original_length,
34
- "sample_rate": self.sample_rate,
35
- "chunk_length": self.chunk_length,
36
- "channels": self.channels,
37
- "padding": self.padding,
38
- "dac_version": SUPPORTED_VERSIONS[-1],
39
- },
40
- }
41
- path = Path(path).with_suffix(".dac")
42
- with open(path, "wb") as f:
43
- np.save(f, artifacts)
44
- return path
45
-
46
- @classmethod
47
- def load(cls, path):
48
- artifacts = np.load(path, allow_pickle=True)[()]
49
- codes = torch.from_numpy(artifacts["codes"].astype(int))
50
- if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
- raise RuntimeError(
52
- f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
- )
54
- return cls(codes=codes, **artifacts["metadata"])
55
-
56
-
57
- class CodecMixin:
58
- @property
59
- def padding(self):
60
- if not hasattr(self, "_padding"):
61
- self._padding = True
62
- return self._padding
63
-
64
- @padding.setter
65
- def padding(self, value):
66
- assert isinstance(value, bool)
67
-
68
- layers = [
69
- l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
- ]
71
-
72
- for layer in layers:
73
- if value:
74
- if hasattr(layer, "original_padding"):
75
- layer.padding = layer.original_padding
76
- else:
77
- layer.original_padding = layer.padding
78
- layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
-
80
- self._padding = value
81
-
82
- def get_delay(self):
83
- # Any number works here, delay is invariant to input length
84
- l_out = self.get_output_length(0)
85
- L = l_out
86
-
87
- layers = []
88
- for layer in self.modules():
89
- if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
- layers.append(layer)
91
-
92
- for layer in reversed(layers):
93
- d = layer.dilation[0]
94
- k = layer.kernel_size[0]
95
- s = layer.stride[0]
96
-
97
- if isinstance(layer, nn.ConvTranspose1d):
98
- L = ((L - d * (k - 1) - 1) / s) + 1
99
- elif isinstance(layer, nn.Conv1d):
100
- L = (L - 1) * s + d * (k - 1) + 1
101
-
102
- L = math.ceil(L)
103
-
104
- l_in = L
105
-
106
- return (l_in - l_out) // 2
107
-
108
- def get_output_length(self, input_length):
109
- L = input_length
110
- # Calculate output length
111
- for layer in self.modules():
112
- if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
- d = layer.dilation[0]
114
- k = layer.kernel_size[0]
115
- s = layer.stride[0]
116
-
117
- if isinstance(layer, nn.Conv1d):
118
- L = ((L - d * (k - 1) - 1) / s) + 1
119
- elif isinstance(layer, nn.ConvTranspose1d):
120
- L = (L - 1) * s + d * (k - 1) + 1
121
-
122
- L = math.floor(L)
123
- return L
124
-
125
- @torch.no_grad()
126
- def compress(
127
- self,
128
- audio_path_or_signal: Union[str, Path, AudioSignal],
129
- win_duration: float = 1.0,
130
- verbose: bool = False,
131
- normalize_db: float = -16,
132
- n_quantizers: int = None,
133
- ) -> DACFile:
134
- """Processes an audio signal from a file or AudioSignal object into
135
- discrete codes. This function processes the signal in short windows,
136
- using constant GPU memory.
137
-
138
- Parameters
139
- ----------
140
- audio_path_or_signal : Union[str, Path, AudioSignal]
141
- audio signal to reconstruct
142
- win_duration : float, optional
143
- window duration in seconds, by default 5.0
144
- verbose : bool, optional
145
- by default False
146
- normalize_db : float, optional
147
- normalize db, by default -16
148
-
149
- Returns
150
- -------
151
- DACFile
152
- Object containing compressed codes and metadata
153
- required for decompression
154
- """
155
- audio_signal = audio_path_or_signal
156
- if isinstance(audio_signal, (str, Path)):
157
- audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
-
159
- self.eval()
160
- original_padding = self.padding
161
- original_device = audio_signal.device
162
-
163
- audio_signal = audio_signal.clone()
164
- original_sr = audio_signal.sample_rate
165
-
166
- resample_fn = audio_signal.resample
167
- loudness_fn = audio_signal.loudness
168
-
169
- # If audio is > 10 minutes long, use the ffmpeg versions
170
- if audio_signal.signal_duration >= 10 * 60 * 60:
171
- resample_fn = audio_signal.ffmpeg_resample
172
- loudness_fn = audio_signal.ffmpeg_loudness
173
-
174
- original_length = audio_signal.signal_length
175
- resample_fn(self.sample_rate)
176
- input_db = loudness_fn()
177
-
178
- if normalize_db is not None:
179
- audio_signal.normalize(normalize_db)
180
- audio_signal.ensure_max_of_audio()
181
-
182
- nb, nac, nt = audio_signal.audio_data.shape
183
- audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
184
- win_duration = (
185
- audio_signal.signal_duration if win_duration is None else win_duration
186
- )
187
-
188
- if audio_signal.signal_duration <= win_duration:
189
- # Unchunked compression (used if signal length < win duration)
190
- self.padding = True
191
- n_samples = nt
192
- hop = nt
193
- else:
194
- # Chunked inference
195
- self.padding = False
196
- # Zero-pad signal on either side by the delay
197
- audio_signal.zero_pad(self.delay, self.delay)
198
- n_samples = int(win_duration * self.sample_rate)
199
- # Round n_samples to nearest hop length multiple
200
- n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
201
- hop = self.get_output_length(n_samples)
202
-
203
- codes = []
204
- range_fn = range if not verbose else tqdm.trange
205
-
206
- for i in range_fn(0, nt, hop):
207
- x = audio_signal[..., i : i + n_samples]
208
- x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
209
-
210
- audio_data = x.audio_data.to(self.device)
211
- audio_data = self.preprocess(audio_data, self.sample_rate)
212
- _, c, _, _, _ = self.encode(audio_data, n_quantizers)
213
- codes.append(c.to(original_device))
214
- chunk_length = c.shape[-1]
215
-
216
- codes = torch.cat(codes, dim=-1)
217
-
218
- dac_file = DACFile(
219
- codes=codes,
220
- chunk_length=chunk_length,
221
- original_length=original_length,
222
- input_db=input_db,
223
- channels=nac,
224
- sample_rate=original_sr,
225
- padding=self.padding,
226
- dac_version=SUPPORTED_VERSIONS[-1],
227
- )
228
-
229
- if n_quantizers is not None:
230
- codes = codes[:, :n_quantizers, :]
231
-
232
- self.padding = original_padding
233
- return dac_file
234
-
235
- @torch.no_grad()
236
- def decompress(
237
- self,
238
- obj: Union[str, Path, DACFile],
239
- verbose: bool = False,
240
- ) -> AudioSignal:
241
- """Reconstruct audio from a given .dac file
242
-
243
- Parameters
244
- ----------
245
- obj : Union[str, Path, DACFile]
246
- .dac file location or corresponding DACFile object.
247
- verbose : bool, optional
248
- Prints progress if True, by default False
249
-
250
- Returns
251
- -------
252
- AudioSignal
253
- Object with the reconstructed audio
254
- """
255
- self.eval()
256
- if isinstance(obj, (str, Path)):
257
- obj = DACFile.load(obj)
258
-
259
- original_padding = self.padding
260
- self.padding = obj.padding
261
-
262
- range_fn = range if not verbose else tqdm.trange
263
- codes = obj.codes
264
- original_device = codes.device
265
- chunk_length = obj.chunk_length
266
- recons = []
267
-
268
- for i in range_fn(0, codes.shape[-1], chunk_length):
269
- c = codes[..., i : i + chunk_length].to(self.device)
270
- z = self.quantizer.from_codes(c)[0]
271
- r = self.decode(z)
272
- recons.append(r.to(original_device))
273
-
274
- recons = torch.cat(recons, dim=-1)
275
- recons = AudioSignal(recons, self.sample_rate)
276
-
277
- resample_fn = recons.resample
278
- loudness_fn = recons.loudness
279
-
280
- # If audio is > 10 minutes long, use the ffmpeg versions
281
- if recons.signal_duration >= 10 * 60 * 60:
282
- resample_fn = recons.ffmpeg_resample
283
- loudness_fn = recons.ffmpeg_loudness
284
-
285
- recons.normalize(obj.input_db)
286
- resample_fn(obj.sample_rate)
287
- recons = recons[..., : obj.original_length]
288
- loudness_fn()
289
- recons.audio_data = recons.audio_data.reshape(
290
- -1, obj.channels, obj.original_length
291
- )
292
-
293
- self.padding = original_padding
294
- return recons
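As a usage note for the compress/decompress API removed above, a minimal sketch (hypothetical; it assumes the removed indextts.s2mel.dac package plus audiotools, uses placeholder file names, and an untrained model stands in for a real pretrained checkpoint) of the chunked round trip through a .dac file:

# Hypothetical round-trip sketch for the removed CodecMixin/DACFile API; paths are placeholders.
from indextts.s2mel.dac import DAC, DACFile  # package removed by this commit

model = DAC()  # untrained here; in practice a pretrained checkpoint would be loaded
dac_file = model.compress("input.wav", win_duration=1.0, verbose=True)  # chunked, constant GPU memory
path = dac_file.save("compressed")                  # uint16 codes + metadata, written to "compressed.dac"
restored = DACFile.load(path)                       # re-checks dac_version against SUPPORTED_VERSIONS
signal = model.decompress(restored, verbose=True)   # returns an audiotools AudioSignal
signal.write("reconstructed.wav")                   # AudioSignal.write comes from audiotools (assumed available)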
indextts/s2mel/dac/model/dac.py DELETED
@@ -1,400 +0,0 @@
1
- import math
2
- from typing import List
3
- from typing import Union
4
-
5
- import numpy as np
6
- import torch
7
- from audiotools import AudioSignal
8
- from audiotools.ml import BaseModel
9
- from torch import nn
10
-
11
- from .base import CodecMixin
12
- from indextts.s2mel.dac.nn.layers import Snake1d
13
- from indextts.s2mel.dac.nn.layers import WNConv1d
14
- from indextts.s2mel.dac.nn.layers import WNConvTranspose1d
15
- from indextts.s2mel.dac.nn.quantize import ResidualVectorQuantize
16
- from .encodec import SConv1d, SConvTranspose1d, SLSTM
17
-
18
-
19
- def init_weights(m):
20
- if isinstance(m, nn.Conv1d):
21
- nn.init.trunc_normal_(m.weight, std=0.02)
22
- nn.init.constant_(m.bias, 0)
23
-
24
-
25
- class ResidualUnit(nn.Module):
26
- def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
27
- super().__init__()
28
- conv1d_type = SConv1d# if causal else WNConv1d
29
- pad = ((7 - 1) * dilation) // 2
30
- self.block = nn.Sequential(
31
- Snake1d(dim),
32
- conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'),
33
- Snake1d(dim),
34
- conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
35
- )
36
-
37
- def forward(self, x):
38
- y = self.block(x)
39
- pad = (x.shape[-1] - y.shape[-1]) // 2
40
- if pad > 0:
41
- x = x[..., pad:-pad]
42
- return x + y
43
-
44
-
45
- class EncoderBlock(nn.Module):
46
- def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
47
- super().__init__()
48
- conv1d_type = SConv1d# if causal else WNConv1d
49
- self.block = nn.Sequential(
50
- ResidualUnit(dim // 2, dilation=1, causal=causal),
51
- ResidualUnit(dim // 2, dilation=3, causal=causal),
52
- ResidualUnit(dim // 2, dilation=9, causal=causal),
53
- Snake1d(dim // 2),
54
- conv1d_type(
55
- dim // 2,
56
- dim,
57
- kernel_size=2 * stride,
58
- stride=stride,
59
- padding=math.ceil(stride / 2),
60
- causal=causal,
61
- norm='weight_norm',
62
- ),
63
- )
64
-
65
- def forward(self, x):
66
- return self.block(x)
67
-
68
-
69
- class Encoder(nn.Module):
70
- def __init__(
71
- self,
72
- d_model: int = 64,
73
- strides: list = [2, 4, 8, 8],
74
- d_latent: int = 64,
75
- causal: bool = False,
76
- lstm: int = 2,
77
- ):
78
- super().__init__()
79
- conv1d_type = SConv1d# if causal else WNConv1d
80
- # Create first convolution
81
- self.block = [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
82
-
83
- # Create EncoderBlocks that double channels as they downsample by `stride`
84
- for stride in strides:
85
- d_model *= 2
86
- self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]
87
-
88
- # Add LSTM if needed
89
- self.use_lstm = lstm
90
- if lstm:
91
- self.block += [SLSTM(d_model, lstm)]
92
-
93
- # Create last convolution
94
- self.block += [
95
- Snake1d(d_model),
96
- conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'),
97
- ]
98
-
99
-         # Wrap block into nn.Sequential
100
- self.block = nn.Sequential(*self.block)
101
- self.enc_dim = d_model
102
-
103
- def forward(self, x):
104
- return self.block(x)
105
-
106
- def reset_cache(self):
107
- # recursively find all submodules named SConv1d in self.block and use their reset_cache method
108
- def reset_cache(m):
109
- if isinstance(m, SConv1d) or isinstance(m, SLSTM):
110
- m.reset_cache()
111
- return
112
- for child in m.children():
113
- reset_cache(child)
114
-
115
- reset_cache(self.block)
116
-
117
-
118
- class DecoderBlock(nn.Module):
119
- def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
120
- super().__init__()
121
- conv1d_type = SConvTranspose1d #if causal else WNConvTranspose1d
122
- self.block = nn.Sequential(
123
- Snake1d(input_dim),
124
- conv1d_type(
125
- input_dim,
126
- output_dim,
127
- kernel_size=2 * stride,
128
- stride=stride,
129
- padding=math.ceil(stride / 2),
130
- causal=causal,
131
- norm='weight_norm'
132
- ),
133
- ResidualUnit(output_dim, dilation=1, causal=causal),
134
- ResidualUnit(output_dim, dilation=3, causal=causal),
135
- ResidualUnit(output_dim, dilation=9, causal=causal),
136
- )
137
-
138
- def forward(self, x):
139
- return self.block(x)
140
-
141
-
142
- class Decoder(nn.Module):
143
- def __init__(
144
- self,
145
- input_channel,
146
- channels,
147
- rates,
148
- d_out: int = 1,
149
- causal: bool = False,
150
- lstm: int = 2,
151
- ):
152
- super().__init__()
153
- conv1d_type = SConv1d# if causal else WNConv1d
154
- # Add first conv layer
155
- layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
156
-
157
- if lstm:
158
- layers += [SLSTM(channels, num_layers=lstm)]
159
-
160
- # Add upsampling + MRF blocks
161
- for i, stride in enumerate(rates):
162
- input_dim = channels // 2**i
163
- output_dim = channels // 2 ** (i + 1)
164
- layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]
165
-
166
- # Add final conv layer
167
- layers += [
168
- Snake1d(output_dim),
169
- conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
170
- nn.Tanh(),
171
- ]
172
-
173
- self.model = nn.Sequential(*layers)
174
-
175
- def forward(self, x):
176
- return self.model(x)
177
-
178
-
179
- class DAC(BaseModel, CodecMixin):
180
- def __init__(
181
- self,
182
- encoder_dim: int = 64,
183
- encoder_rates: List[int] = [2, 4, 8, 8],
184
- latent_dim: int = None,
185
- decoder_dim: int = 1536,
186
- decoder_rates: List[int] = [8, 8, 4, 2],
187
- n_codebooks: int = 9,
188
- codebook_size: int = 1024,
189
- codebook_dim: Union[int, list] = 8,
190
- quantizer_dropout: bool = False,
191
- sample_rate: int = 44100,
192
- lstm: int = 2,
193
- causal: bool = False,
194
- ):
195
- super().__init__()
196
-
197
- self.encoder_dim = encoder_dim
198
- self.encoder_rates = encoder_rates
199
- self.decoder_dim = decoder_dim
200
- self.decoder_rates = decoder_rates
201
- self.sample_rate = sample_rate
202
-
203
- if latent_dim is None:
204
- latent_dim = encoder_dim * (2 ** len(encoder_rates))
205
-
206
- self.latent_dim = latent_dim
207
-
208
- self.hop_length = np.prod(encoder_rates)
209
- self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)
210
-
211
- self.n_codebooks = n_codebooks
212
- self.codebook_size = codebook_size
213
- self.codebook_dim = codebook_dim
214
- self.quantizer = ResidualVectorQuantize(
215
- input_dim=latent_dim,
216
- n_codebooks=n_codebooks,
217
- codebook_size=codebook_size,
218
- codebook_dim=codebook_dim,
219
- quantizer_dropout=quantizer_dropout,
220
- )
221
-
222
- self.decoder = Decoder(
223
- latent_dim,
224
- decoder_dim,
225
- decoder_rates,
226
- lstm=lstm,
227
- causal=causal,
228
- )
229
- self.sample_rate = sample_rate
230
- self.apply(init_weights)
231
-
232
- self.delay = self.get_delay()
233
-
234
- def preprocess(self, audio_data, sample_rate):
235
- if sample_rate is None:
236
- sample_rate = self.sample_rate
237
- assert sample_rate == self.sample_rate
238
-
239
- length = audio_data.shape[-1]
240
- right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
241
- audio_data = nn.functional.pad(audio_data, (0, right_pad))
242
-
243
- return audio_data
244
-
245
- def encode(
246
- self,
247
- audio_data: torch.Tensor,
248
- n_quantizers: int = None,
249
- ):
250
- """Encode given audio data and return quantized latent codes
251
-
252
- Parameters
253
- ----------
254
- audio_data : Tensor[B x 1 x T]
255
- Audio data to encode
256
- n_quantizers : int, optional
257
- Number of quantizers to use, by default None
258
- If None, all quantizers are used.
259
-
260
- Returns
261
- -------
262
- dict
263
- A dictionary with the following keys:
264
- "z" : Tensor[B x D x T]
265
- Quantized continuous representation of input
266
- "codes" : Tensor[B x N x T]
267
- Codebook indices for each codebook
268
- (quantized discrete representation of input)
269
- "latents" : Tensor[B x N*D x T]
270
- Projected latents (continuous representation of input before quantization)
271
- "vq/commitment_loss" : Tensor[1]
272
- Commitment loss to train encoder to predict vectors closer to codebook
273
- entries
274
- "vq/codebook_loss" : Tensor[1]
275
- Codebook loss to update the codebook
276
- "length" : int
277
- Number of samples in input audio
278
- """
279
- z = self.encoder(audio_data)
280
- z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
281
- z, n_quantizers
282
- )
283
- return z, codes, latents, commitment_loss, codebook_loss
284
-
285
- def decode(self, z: torch.Tensor):
286
- """Decode given latent codes and return audio data
287
-
288
- Parameters
289
- ----------
290
- z : Tensor[B x D x T]
291
- Quantized continuous representation of input
292
- length : int, optional
293
- Number of samples in output audio, by default None
294
-
295
- Returns
296
- -------
297
- dict
298
- A dictionary with the following keys:
299
- "audio" : Tensor[B x 1 x length]
300
- Decoded audio data.
301
- """
302
- return self.decoder(z)
303
-
304
- def forward(
305
- self,
306
- audio_data: torch.Tensor,
307
- sample_rate: int = None,
308
- n_quantizers: int = None,
309
- ):
310
- """Model forward pass
311
-
312
- Parameters
313
- ----------
314
- audio_data : Tensor[B x 1 x T]
315
- Audio data to encode
316
- sample_rate : int, optional
317
- Sample rate of audio data in Hz, by default None
318
- If None, defaults to `self.sample_rate`
319
- n_quantizers : int, optional
320
- Number of quantizers to use, by default None.
321
- If None, all quantizers are used.
322
-
323
- Returns
324
- -------
325
- dict
326
- A dictionary with the following keys:
327
- "z" : Tensor[B x D x T]
328
- Quantized continuous representation of input
329
- "codes" : Tensor[B x N x T]
330
- Codebook indices for each codebook
331
- (quantized discrete representation of input)
332
- "latents" : Tensor[B x N*D x T]
333
- Projected latents (continuous representation of input before quantization)
334
- "vq/commitment_loss" : Tensor[1]
335
- Commitment loss to train encoder to predict vectors closer to codebook
336
- entries
337
- "vq/codebook_loss" : Tensor[1]
338
- Codebook loss to update the codebook
339
- "length" : int
340
- Number of samples in input audio
341
- "audio" : Tensor[B x 1 x length]
342
- Decoded audio data.
343
- """
344
- length = audio_data.shape[-1]
345
- audio_data = self.preprocess(audio_data, sample_rate)
346
- z, codes, latents, commitment_loss, codebook_loss = self.encode(
347
- audio_data, n_quantizers
348
- )
349
-
350
- x = self.decode(z)
351
- return {
352
- "audio": x[..., :length],
353
- "z": z,
354
- "codes": codes,
355
- "latents": latents,
356
- "vq/commitment_loss": commitment_loss,
357
- "vq/codebook_loss": codebook_loss,
358
- }
359
-
360
-
361
- if __name__ == "__main__":
362
- import numpy as np
363
- from functools import partial
364
-
365
- model = DAC().to("cpu")
366
-
367
- for n, m in model.named_modules():
368
- o = m.extra_repr()
369
- p = sum([np.prod(p.size()) for p in m.parameters()])
370
- fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
371
- setattr(m, "extra_repr", partial(fn, o=o, p=p))
372
- print(model)
373
- print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
374
-
375
- length = 88200 * 2
376
- x = torch.randn(1, 1, length).to(model.device)
377
- x.requires_grad_(True)
378
- x.retain_grad()
379
-
380
- # Make a forward pass
381
- out = model(x)["audio"]
382
- print("Input shape:", x.shape)
383
- print("Output shape:", out.shape)
384
-
385
- # Create gradient variable
386
- grad = torch.zeros_like(out)
387
- grad[:, :, grad.shape[-1] // 2] = 1
388
-
389
- # Make a backward pass
390
- out.backward(grad)
391
-
392
- # Check non-zero values
393
- gradmap = x.grad.squeeze(0)
394
- gradmap = (gradmap != 0).sum(0) # sum across features
395
- rf = (gradmap != 0).sum()
396
-
397
- print(f"Receptive field: {rf.item()}")
398
-
399
- x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
400
- model.decompress(model.compress(x, verbose=True), verbose=True)
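For completeness, a small hypothetical sketch of the lower-level DAC API defined above, showing that encode returns a 5-tuple (rather than the dict its docstring mentions) and that decode maps latents back to audio; the model is untrained and the input is random noise:

# Hypothetical sketch of the removed DAC.encode/decode API.
import torch
from indextts.s2mel.dac import DAC  # package removed by this commit

model = DAC()                                   # 44.1 kHz codec; hop_length = prod([2, 4, 8, 8]) = 512
x = torch.randn(1, 1, 44100)                    # one second of audio, shape [B, 1, T]
x = model.preprocess(x, 44100)                  # right-pad T to a multiple of hop_length
z, codes, latents, commit_loss, codebook_loss = model.encode(x, n_quantizers=4)
audio = model.decode(z)                         # shape [B, 1, T_padded]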
indextts/s2mel/dac/model/discriminator.py DELETED
@@ -1,228 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from audiotools import AudioSignal
5
- from audiotools import ml
6
- from audiotools import STFTParams
7
- from einops import rearrange
8
- from torch.nn.utils import weight_norm
9
-
10
-
11
- def WNConv1d(*args, **kwargs):
12
- act = kwargs.pop("act", True)
13
- conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
- if not act:
15
- return conv
16
- return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
-
18
-
19
- def WNConv2d(*args, **kwargs):
20
- act = kwargs.pop("act", True)
21
- conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
- if not act:
23
- return conv
24
- return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
-
26
-
27
- class MPD(nn.Module):
28
- def __init__(self, period):
29
- super().__init__()
30
- self.period = period
31
- self.convs = nn.ModuleList(
32
- [
33
- WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
- WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
- WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
- WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
- WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
- ]
39
- )
40
- self.conv_post = WNConv2d(
41
- 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
- )
43
-
44
- def pad_to_period(self, x):
45
- t = x.shape[-1]
46
- x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
- return x
48
-
49
- def forward(self, x):
50
- fmap = []
51
-
52
- x = self.pad_to_period(x)
53
- x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
-
55
- for layer in self.convs:
56
- x = layer(x)
57
- fmap.append(x)
58
-
59
- x = self.conv_post(x)
60
- fmap.append(x)
61
-
62
- return fmap
63
-
64
-
65
- class MSD(nn.Module):
66
- def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
- super().__init__()
68
- self.convs = nn.ModuleList(
69
- [
70
- WNConv1d(1, 16, 15, 1, padding=7),
71
- WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
- WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
- WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
- WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
- WNConv1d(1024, 1024, 5, 1, padding=2),
76
- ]
77
- )
78
- self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
- self.sample_rate = sample_rate
80
- self.rate = rate
81
-
82
- def forward(self, x):
83
- x = AudioSignal(x, self.sample_rate)
84
- x.resample(self.sample_rate // self.rate)
85
- x = x.audio_data
86
-
87
- fmap = []
88
-
89
- for l in self.convs:
90
- x = l(x)
91
- fmap.append(x)
92
- x = self.conv_post(x)
93
- fmap.append(x)
94
-
95
- return fmap
96
-
97
-
98
- BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
-
100
-
101
- class MRD(nn.Module):
102
- def __init__(
103
- self,
104
- window_length: int,
105
- hop_factor: float = 0.25,
106
- sample_rate: int = 44100,
107
- bands: list = BANDS,
108
- ):
109
- """Complex multi-band spectrogram discriminator.
110
- Parameters
111
- ----------
112
- window_length : int
113
- Window length of STFT.
114
- hop_factor : float, optional
115
- Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
- sample_rate : int, optional
117
- Sampling rate of audio in Hz, by default 44100
118
- bands : list, optional
119
- Bands to run discriminator over.
120
- """
121
- super().__init__()
122
-
123
- self.window_length = window_length
124
- self.hop_factor = hop_factor
125
- self.sample_rate = sample_rate
126
- self.stft_params = STFTParams(
127
- window_length=window_length,
128
- hop_length=int(window_length * hop_factor),
129
- match_stride=True,
130
- )
131
-
132
- n_fft = window_length // 2 + 1
133
- bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
- self.bands = bands
135
-
136
- ch = 32
137
- convs = lambda: nn.ModuleList(
138
- [
139
- WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
- WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
- WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
- WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
- WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
- ]
145
- )
146
- self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
- self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
-
149
- def spectrogram(self, x):
150
- x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
- x = torch.view_as_real(x.stft())
152
- x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
- # Split into bands
154
- x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
- return x_bands
156
-
157
- def forward(self, x):
158
- x_bands = self.spectrogram(x)
159
- fmap = []
160
-
161
- x = []
162
- for band, stack in zip(x_bands, self.band_convs):
163
- for layer in stack:
164
- band = layer(band)
165
- fmap.append(band)
166
- x.append(band)
167
-
168
- x = torch.cat(x, dim=-1)
169
- x = self.conv_post(x)
170
- fmap.append(x)
171
-
172
- return fmap
173
-
174
-
175
- class Discriminator(nn.Module):
176
- def __init__(
177
- self,
178
- rates: list = [],
179
- periods: list = [2, 3, 5, 7, 11],
180
- fft_sizes: list = [2048, 1024, 512],
181
- sample_rate: int = 44100,
182
- bands: list = BANDS,
183
- ):
184
- """Discriminator that combines multiple discriminators.
185
-
186
- Parameters
187
- ----------
188
- rates : list, optional
189
- sampling rates (in Hz) to run MSD at, by default []
190
- If empty, MSD is not used.
191
- periods : list, optional
192
- periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
- fft_sizes : list, optional
194
- Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
- sample_rate : int, optional
196
- Sampling rate of audio in Hz, by default 44100
197
- bands : list, optional
198
- Bands to run MRD at, by default `BANDS`
199
- """
200
- super().__init__()
201
- discs = []
202
- discs += [MPD(p) for p in periods]
203
- discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
- discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
- self.discriminators = nn.ModuleList(discs)
206
-
207
- def preprocess(self, y):
208
- # Remove DC offset
209
- y = y - y.mean(dim=-1, keepdims=True)
210
- # Peak normalize the volume of input audio
211
- y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
- return y
213
-
214
- def forward(self, x):
215
- x = self.preprocess(x)
216
- fmaps = [d(x) for d in self.discriminators]
217
- return fmaps
218
-
219
-
220
- if __name__ == "__main__":
221
- disc = Discriminator()
222
- x = torch.zeros(1, 1, 44100)
223
- results = disc(x)
224
- for i, result in enumerate(results):
225
- print(f"disc{i}")
226
- for i, r in enumerate(result):
227
- print(r.shape, r.mean(), r.min(), r.max())
228
- print()