IvA committed
Commit 7cf31e7 · 1 Parent(s): db6d15c
Files changed (5)
  1. .gitattributes +2 -0
  2. README.md +415 -3
  3. config.json +2 -97
  4. true.wav +3 -0
  5. video.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ true.wav filter=lfs diff=lfs merge=lfs -text
+ video.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,415 @@
- ---
- license: cc-by-4.0
- ---
+ ---
+ license: cc-by-4.0
+ language:
+ - en
+ tags:
+ - audio
+ - speech
+ - tokenizer
+ - vocoder
+ base_model:
+ - kyutai/mimi
+ ---
+
+ ## Attentionless VOcoder Streaming
+
+ <video width="1280" height="720" controls style="box-shadow: 0px 0px 20px 10px rgba(0, 0, 0, 0.05), 0px 1px 3px 10px rgba(255, 255, 255, 0.05);">
+   <source src="https://huggingface.co/ivao0/_AvoS/resolve/main/video.mp4" type="video/mp4">
+   Your browser does not support the video tag.
+ </video>
+
+ ## Usage
+
+ The snippet below is self-contained: it defines the streaming model, loads the released weights, and round-trips `true.wav` through `encode`/`decode`.
+
+ ```python
+ from huggingface_hub import hf_hub_download
+ import soundfile
+ import torch
+ from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class Voc(Wav2Vec2PreTrainedModel):
+
+     '''For using a different batch size -> Voc._flush()
+     '''
+
+     def __init__(self,
+                  config=PretrainedConfig(), n_q=18):
+
+         super().__init__(config=config)
+         self.n_q = n_q  # number of codebooks, referenced by encode() when no full frame is available
+         self.encoder_transformer = VocTransformer()
+         self.decoder_transformer = VocTransformer()
+         self.encoder = SEANetEncoder()
+         self.decoder = SEANetDecoder()
+         self.sample_rate = 24000
+         self.quantizer = SplitResidualVectorQuantizer(n_q=n_q)
+         self.downsample = BufferConv1d(512, 512, kernel_size=4, stride=2, groups=1, bias=False)
+         upsample_channel_wise_bug = True
+         self.upsample = BufferConvTranspose1d(512, 512, kernel_size=4,
+                                               groups=512 if upsample_channel_wise_bug else 1,
+                                               stride=2, bias=False)
+         self.frame_rate = 12.5  # 24000 Hz / 12.5 Hz = 1920 samples per frame
+         self.encode_buffer = None  # holds a raw audio remainder shorter than 1920 samples
+
+     @torch.no_grad()
+     def _flush(self):
+         '''Stream buffers hold tensors of the old batch size; call _flush() to clear them.
+         '''
+         self.encode_buffer = None  # holds incomplete windows (< 1920 samples) - 1920 samples are needed to produce 1 token
+         if self.downsample.previous is not None:
+             self.downsample.previous = None
+         if self.upsample.partial is not None:
+             self.upsample.partial = None
+         for arch in [self.encoder, self.decoder]:
+             for _m in arch.model:
+                 if type(_m) is SEANetResnetBlock:
+                     for _b in _m.block:
+                         if type(_b) is BufferConv1d:
+                             if _b.previous is not None:
+                                 _b.previous = None
+                 if type(_m) is BufferConv1d:
+                     if _m.previous is not None:
+                         _m.previous = None
+                 if type(_m) is BufferConvTranspose1d:
+                     if _m.partial is not None:
+                         _m.partial = None
+
+     @torch.no_grad()
+     def encode(self, x):
+         '''24 kHz audio to codes
+         x : [bs, 1, n_samples] at 24 kHz
+         c : [bs, n_codebooks, time] - every 1920 audio samples produce 1 time frame (of n_q codebooks)
+         '''
+         if self.encode_buffer is not None:
+             x = torch.cat([self.encode_buffer, x], 2)
+         _bs, _1, _len = x.shape
+         num_frames = int(_len / 1920)
+         leftover = x[:, :, num_frames * 1920:]  # buffer the incomplete tail for the next call
+         if leftover.shape[2] > 0:
+             self.encode_buffer = leftover
+         else:
+             self.encode_buffer = None
+             torch.cuda.empty_cache()
+         if num_frames > 0:
+             c = []
+             for n in range(num_frames):
+                 e = self.encoder(x[:, :, n * 1920:(n + 1) * 1920])
+                 e = self.encoder_transformer(e)
+                 e = self.downsample(e)
+                 _c = self.quantizer.encode(e)
+                 c.append(_c)
+             c = torch.cat(c, 2)
+         else:
+             # num_frames = 0 early exit: x.shape[2] < 1920 only fills the buffer, no token can be output yet
+             c = torch.empty(_bs, self.n_q, 0)
+         return c
+
+     @torch.no_grad()
+     def decode(self, c):
+         '''codes to 24 kHz audio
+         c: [bs, n_codebooks, n_tokens]
+         x: [bs, 1, n_tokens * 1920]
+         '''
+         _hidden = []
+         for i in range(c.shape[2]):
+             x = self.quantizer.decode(c[:, :, i:i+1])
+             x = self.upsample(x)
+             x = self.decoder_transformer(x)
+             x = self.decoder(x)
+             _hidden.append(x)
+         return torch.cat(_hidden, 2)  # [bs, 1, n_tokens * 1920] at 24 kHz
+
+
+ class SEANetResnetBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         kernel_sizes=[3, 1],
+     ):
+         super().__init__()
+
+         block = []
+         for i, kernel_size in enumerate(kernel_sizes):
+
+             block += [
+                 nn.ELU(),
+                 BufferConv1d(
+                     dim if i == 0 else dim // 2,
+                     dim // 2 if i == 0 else dim,
+                     kernel_size=kernel_size,
+                     bias=True,
+                 ),
+             ]
+
+         self.block = nn.Sequential(*block)
+
+     def forward(self, x):
+         return x + self.block(x)  # BufferConv1d guarantees a full kernel window (left pad or previous buffer), so output length matches x for the residual add
+
+
+ class SEANetEncoder(nn.Module):
+     def __init__(
+         self,
+         channels=1,  # DOES NOT SUPPORT STEREO
+         dimension=512,
+         n_filters=64,
+         ratios=[8, 6, 5, 4],
+         kernel_size=7,
+         last_kernel_size=3,
+     ):
+         super().__init__()
+         self.ratios = list(reversed(ratios))
+         del ratios
+         mult = 1  # doubled on every loop iteration below
+         model = [
+             BufferConv1d(
+                 channels,
+                 mult * n_filters,
+                 kernel_size,
+                 bias=True
+             )
+         ]
+         for i, ratio in enumerate(self.ratios):
+             model += [SEANetResnetBlock(mult * n_filters),
+                       nn.ELU(),
+                       BufferConv1d(mult * n_filters,
+                                    mult * n_filters * 2,
+                                    kernel_size=ratio * 2,
+                                    stride=ratio,
+                                    bias=True)]
+             mult *= 2
+         # ENDFOR
+         model += [nn.ELU(),
+                   BufferConv1d(mult * n_filters,
+                                dimension,
+                                last_kernel_size,
+                                bias=True)]
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class SEANetDecoder(nn.Module):
+
+     def __init__(
+         self,
+         channels=1,
+         dimension=512,
+         n_filters=64,
+         ratios=[8, 6, 5, 4],
+         kernel_size=7,
+         last_kernel_size=3):
+
+         super().__init__()
+         mult = int(2 ** len(ratios))
+         model = [BufferConv1d(dimension,
+                               mult * n_filters,
+                               kernel_size,
+                               bias=True)]
+         # UP
+         for i, ratio in enumerate(ratios):
+             model += [nn.ELU(),
+                       BufferConvTranspose1d(mult * n_filters,
+                                             mult * n_filters // 2,
+                                             kernel_size=ratio * 2,
+                                             stride=ratio,
+                                             bias=True),
+                       SEANetResnetBlock(mult * n_filters // 2)]
+             mult //= 2
+         # LAST
+         model += [
+             nn.ELU(),
+             BufferConv1d(
+                 n_filters,
+                 channels,
+                 last_kernel_size,
+                 bias=True
+             ),
+         ]
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class BufferConv1d(nn.Conv1d):
+     def __init__(self,
+                  *args,
+                  **kwargs):
+         super().__init__(*args,
+                          **kwargs)
+         self.previous = None
+
+     def forward(self, x):
+         k = self.kernel_size[0]
+
+         if self.previous is not None:
+
+             x = torch.cat([self.previous, x], 2)
+
+         else:  # first call: no buffer yet => left-pad instead
+
+             if k == 3:
+
+                 p = (2, 0)
+                 x = F.pad(x, p, mode='replicate', value=0.0)  # skip connections in SEANetResnetBlock
+
+             elif k == 4:  # ConvTrUpsample is the first conv encountered by decode; replicate avoids a pulse
+
+                 p = (3, 0)
+                 x = F.pad(x, p, mode='replicate', value=0.0)
+
+             elif k == 7:
+
+                 p = (6, 0)
+                 x = F.pad(x, p, mode='replicate', value=0.0)
+
+             elif k == 16:
+
+                 p = (2, 0)
+                 x = F.pad(x, p, mode='replicate', value=0.0)  # this could also be constant padding without a pulse occurring
+
+         num_frames = int((x.shape[2] - self.kernel_size[0]) / self.stride[0]) + 1  # +1: the kernel starts at the left of x and makes (I - k) / s jumps
+         offset = num_frames * self.stride[0]
+         self.previous = x[..., offset:]
+         return super().forward(x)
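+     # Streaming behaviour: after a full window, `previous` keeps the last
+     # kernel_size - stride input samples (plus any remainder that did not
+     # complete a stride), so the next chunk is convolved causally against
+     # exactly the history it needs, with no look-ahead.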
+
+
+ class BufferConvTranspose1d(nn.ConvTranspose1d):
+     # a kernel of 5 effectively has only 1 input pixel (the output is a scalar replication of 1 pixel)
+     # https://distill.pub/2016/deconv-checkerboard/
+     def __init__(self,
+                  *args,
+                  **kwargs):
+         super().__init__(*args,
+                          **kwargs)
+         self.partial = None
+
+     def forward(self, x):
+         out = super().forward(x)
+         OT = out.shape[2]
+         invalid_steps = self.kernel_size[0] - self.stride[0]
+         if self.partial is not None:
+             PT = self.partial.shape[-1]
+             if self.bias is not None:
+                 out[..., :PT] += self.partial - self.bias[:, None]
+             else:
+                 out[..., :PT] += self.partial  # for ConvTrUpsample1d
+         self.partial = out[..., OT - invalid_steps:]
+         out = out[..., :OT - invalid_steps]
+         return out
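+     # Streaming behaviour: the last kernel_size - stride output steps of each
+     # chunk are still incomplete (future inputs also contribute to them), so they
+     # are held in `partial` and added onto the start of the next chunk's output;
+     # the bias is subtracted there once so it is not counted twice.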
+
+
+ class CodeBook(nn.Module):
+     def __init__(self, dim, codebook_size):
+         super().__init__()
+         self.register_buffer('_e', torch.zeros(codebook_size, dim))
+
+     def encode(self, x):
+         dist = torch.cdist(
+             x.transpose(1, 2),   # [bs, time, 256]
+             self._e[None, :, :]  # [1, 2048, 256]
+         )
+         codes = dist.argmin(2)
+         return codes
+
+     def decode(self, codes):
+         quantized = F.embedding(codes, self._e)
+         return quantized.transpose(1, 2)  # [1, 256, time]
+
+
+ class SplitResidualVectorQuantizer(nn.Module):
+
+     def __init__(self,
+                  n_q=None,
+                  # https://huggingface.co/kyutai/moshiko-pytorch-bf16/blob/main/tokenizer-e351c8d8-checkpoint125.safetensors
+                  ):
+         super().__init__()
+         self.in_proj_s = torch.nn.Conv1d(512, 256, 1, bias=False)
+         self.in_proj_a = torch.nn.Conv1d(512, 256, 1, bias=False)
+         self.out_proj_s = torch.nn.Conv1d(256, 512, 1, bias=False)  # reused for all 31 acoustic books
+         self.out_proj_a = torch.nn.Conv1d(256, 512, 1, bias=False)
+         self.layers = nn.ModuleList([CodeBook(dim=256, codebook_size=2048) for _ in range(18)])
+         self._acoustic_books = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 17, 17, 17, 17]  # list(range(1, n_q)) - books in use; book 0 (semantic) excluded
+
+     def encode(self, x):
+         indices = self.layers[0].encode(self.in_proj_s(x))  # integers
+         all_indices = [indices[:, None, :], ]
+         x = self.in_proj_a(x)
+         for _cb in self._acoustic_books:
+             indices = self.layers[_cb].encode(x)
+             x = x - self.layers[_cb].decode(indices)
+             all_indices.append(indices[:, None, :])
+         codes = torch.cat(all_indices, 1)
+         return codes
+
+     def decode(self, codes):
+         _s = self.layers[0].decode(codes[:, 0, :])
+         _a = torch.zeros([1, 1], device=codes.device)
+         for i, _cb in enumerate(self._acoustic_books):
+             _a = _a + self.layers[_cb].decode(codes[:, i+1, :])
+         return self.out_proj_s(_s) + self.out_proj_a(_a)  # [bs, 512, time]
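+     # Split residual VQ: book 0 is the semantic codebook with its own input/output
+     # projection; the remaining acoustic books quantize the running residual in
+     # encode(), and decode() sums their embeddings back together before projecting.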
+
+
+ class VocAttention(nn.Module):
+
+     def __init__(self,
+                  embed_dim):
+
+         super().__init__()
+         self.fused_proj = nn.Parameter(torch.zeros(embed_dim, embed_dim))
+
+     def forward(self, x):
+         '''bypass of streaming training'''
+         if x.shape[1] > 1:
+             x = x.mean(1, keepdim=True)
+         x = torch.matmul(x, self.fused_proj)
+         return x  # the FFN / residual add broadcasts back to the original x.shape[1]
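+     # "Attentionless": self-attention is replaced by a mean over time followed by
+     # a single fused projection, so no attention weights or KV cache are needed
+     # and the layer streams one frame at a time.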
+
+
+ class VocTransformerLayer(nn.Module):
+
+     def __init__(self, d_model=512, dim_feedforward=2048):
+         super().__init__()
+         self.self_attn = VocAttention(embed_dim=d_model)
+         self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
+         self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
+         self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
+         self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
+
+     def forward(self, x):
+         x = x + self.self_attn(self.norm1(x))
+         return x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
+
+
+ class VocTransformer(nn.Module):
+
+     def __init__(self,
+                  n_layers=8):
+
+         super().__init__()
+         self.layers = nn.ModuleList(VocTransformerLayer() for _ in range(n_layers))
+
+     def forward(self, x):
+         x = x.transpose(1, 2)
+         for la in self.layers:
+             x = la(x)
+         return x.transpose(1, 2)
+
+ device = 'cpu'  # 'cuda:0'
+ model = Voc.from_pretrained('ivao0/voc').to(device)
+ true_audio = hf_hub_download(repo_id='ivao0/voc', filename='true.wav')
+ x, _sr = soundfile.read(true_audio)  # load the 24 kHz mono example
+ x = torch.from_numpy(x[None, None, :]).to(dtype=torch.float, device=device)
+ y = model.decode(model.encode(x))  # notice: if len < 1920 audio samples, codes is an empty tensor
+
+ soundfile.write('reconstruct.wav', y[0, 0, :].cpu().numpy(), 24000)
+
+ model._flush()  # clear streaming buffers
+ y = model.decode(model.encode(x.repeat(6, 1, 1)))  # switch batch size
+ ```
+
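+ Audio does not have to arrive in multiples of 1920 samples: `encode` buffers any
+ incomplete tail in `encode_buffer`, and the `BufferConv1d` / `BufferConvTranspose1d`
+ layers carry their own state between calls. A minimal chunk-by-chunk sketch
+ (the chunk size is arbitrary; `model` and `x` are reused from the snippet above):
+
+ ```python
+ model._flush()                       # start a fresh stream
+ codes, audio = [], []
+ chunk = 1000                         # any size; chunks < 1920 only fill the buffer
+ for i in range(0, x.shape[2], chunk):
+     c = model.encode(x[:, :, i:i + chunk])
+     if c.shape[2] > 0:               # a frame appears once 1920 samples are buffered
+         codes.append(c)
+         audio.append(model.decode(c))
+ y_stream = torch.cat(audio, 2)       # 24 kHz reconstruction, produced frame by frame
+ soundfile.write('reconstruct_stream.wav', y_stream[0, 0, :].cpu().numpy(), 24000)
+ ```
+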
config.json CHANGED
@@ -6,102 +6,7 @@
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
- "Voc"
+ "AvoS"
  ],
- "attention_dropout": 0.1,
- "bos_token_id": 1,
- "classifier_proj_size": 256,
- "codevector_dim": 256,
- "contrastive_logits_temperature": 0.1,
- "conv_bias": false,
- "conv_dim": [
- 512,
- 512,
- 512,
- 512,
- 512,
- 512,
- 512
- ],
- "conv_kernel": [
- 10,
- 3,
- 3,
- 3,
- 3,
- 2,
- 2
- ],
- "conv_stride": [
- 5,
- 2,
- 2,
- 2,
- 2,
- 2,
- 2
- ],
- "ctc_loss_reduction": "sum",
- "ctc_zero_infinity": false,
- "diversity_loss_weight": 0.1,
- "do_stable_layer_norm": false,
- "dtype": "float32",
- "eos_token_id": 2,
- "feat_extract_activation": "gelu",
- "feat_extract_norm": "group",
- "feat_proj_dropout": 0.0,
- "feat_quantizer_dropout": 0.0,
- "final_dropout": 0.1,
- "hidden_act": "gelu",
- "hidden_dropout": 0.1,
- "hidden_size": 768,
- "initializer_range": 0.02,
- "intermediate_size": 3072,
- "layer_norm_eps": 1e-05,
- "layerdrop": 0.1,
- "mask_feature_length": 10,
- "mask_feature_min_masks": 0,
- "mask_feature_prob": 0.0,
- "mask_time_length": 10,
- "mask_time_min_masks": 2,
- "mask_time_prob": 0.05,
- "model_type": "wav2vec2",
- "num_adapter_layers": 3,
- "num_attention_heads": 12,
- "num_codevector_groups": 2,
- "num_codevectors_per_group": 320,
- "num_conv_pos_embedding_groups": 16,
- "num_conv_pos_embeddings": 128,
- "num_feat_extract_layers": 7,
- "num_hidden_layers": 12,
- "num_negatives": 100,
- "output_hidden_size": 768,
- "pad_token_id": 0,
- "proj_codevector_dim": 256,
- "sample_rate": 24000,
- "tdnn_dilation": [
- 1,
- 2,
- 3,
- 1,
- 1
- ],
- "tdnn_dim": [
- 512,
- 512,
- 512,
- 512,
- 1500
- ],
- "tdnn_kernel": [
- 5,
- 3,
- 3,
- 1,
- 1
- ],
- "transformers_version": "4.57.0",
- "use_weighted_layer_sum": false,
- "vocab_size": 32,
- "xvector_output_dim": 512
+ "attention_dropout": 0.1
  }
true.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e2099b05b8a61a17060c2d4a26f1e859b94ec63685a20a00a5e3991ec72a189
+ size 3840044
video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b0b7400ccc97cfef03bfc7b239130e9ced382eb1cd9dfa900ef723d0634742c9
+ size 1686059