yueyulin committed
Commit e878aac · verified · 1 Parent(s): c5ddd69

Upload folder using huggingface_hub
asr_gradio/rwkv7-g1a-0.4b-20250905-ctx4096.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d852e99ef6c95726109660c64e7c51a8df30c53b0832a68645bfcd15253b3109
+ size 901776757
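
Each ADDED entry above and below with a version/oid/size body is a Git LFS pointer: the repository stores only this three-line stub, while the actual blob lives in LFS storage. A minimal sketch (assuming the blob has already been fetched, e.g. via huggingface_hub or git lfs pull) to check a local file against the pointer's oid:

    # Verify a downloaded LFS blob against the sha256 oid in its pointer file.
    import hashlib

    def sha256_of(path, chunk_size=1 << 20):
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    expected = "d852e99ef6c95726109660c64e7c51a8df30c53b0832a68645bfcd15253b3109"
    assert sha256_of("asr_gradio/rwkv7-g1a-0.4b-20250905-ctx4096.pth") == expected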
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/config.json ADDED
@@ -0,0 +1,50 @@
+ {
+   "a_low_rank_dim": 64,
+   "architectures": [
+     "RWKV7ModelForLatentInputs"
+   ],
+   "attn": null,
+   "attn_mode": "chunk",
+   "bos_token_id": 0,
+   "decay_low_rank_dim": 64,
+   "eos_token_id": 0,
+   "fuse_cross_entropy": true,
+   "fuse_linear_cross_entropy": false,
+   "fuse_norm": false,
+   "gate_low_rank_dim": 128,
+   "head_dim": 64,
+   "hidden_act": "sqrelu",
+   "hidden_ratio": 4.0,
+   "hidden_size": 768,
+   "initializer_range": 0.006,
+   "intermediate_size": 3072,
+   "max_position_embeddings": 2048,
+   "model_type": "rwkv7",
+   "norm_bias": true,
+   "norm_eps": 1e-05,
+   "norm_first": true,
+   "num_heads": 32,
+   "num_hidden_layers": 12,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "train_time_state": false,
+   "transformers_version": "4.52.4",
+   "use_cache": true,
+   "use_l2warp": true,
+   "v_low_rank_dim": 32,
+   "value_dim": [
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768
+   ],
+   "vocab_size": 10
+ }
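
The new config describes a 12-layer RWKV-7 model with hidden size 768 and a vocab_size of 10, so the head is clearly not a text vocabulary; it matches the RWKV7ModelForLatentInputs architecture name, a model operating on audio latents. A minimal sketch for inspecting it with plain json (RWKV7ModelForLatentInputs is a custom class, presumably shipped with the repo's own code, so stock transformers AutoConfig may not resolve it):

    # Inspect the uploaded config without importing the custom model class.
    import json

    with open("asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/config.json") as f:
        cfg = json.load(f)

    print(cfg["model_type"], cfg["num_hidden_layers"], cfg["hidden_size"], cfg["vocab_size"])
    # -> rwkv7 12 768 10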
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ecf58ad41179876c07d341991b5bbfb2796d2669ebcb76657d7befe7b139934
+ size 361523904
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/model_converted.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3dc29a8ee50e423522d8f0710f4942f7934a574889fd5b84b104b1bc2eb3139d
+ size 361668745
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector1.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d2d102133e65bb3da053d33c70d05083a43221b9e41ac7e37d759ddc98f909aa
+ size 3937149
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector2.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65febcb6484e4a5464d0df62fe088d4533e8dac2724e1c2fcabe59d60ac0b1b2
+ size 3151741
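
projector1.pt and projector2.pt appear to hold plain Linear state dicts: load_asr_models in the diff below rebuilds each torch.nn.Linear from the stored weight shape. A minimal sketch of that reconstruction (the file-to-projector mapping here is inferred from the code, not stated in the commit):

    # Rebuild a projector Linear from its saved state dict, as the loader below does.
    import torch

    state = torch.load("asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector2.pt",
                       map_location="cpu")
    # weight is (out_features, in_features), hence the reversed shape indices
    proj = torch.nn.Linear(state["weight"].shape[1], state["weight"].shape[0])
    proj.load_state_dict(state)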
asr_gradio/utils/__pycache__/asr_inference_with_chatrwkv.cpython-311.pyc CHANGED
Binary files a/asr_gradio/utils/__pycache__/asr_inference_with_chatrwkv.cpython-311.pyc and b/asr_gradio/utils/__pycache__/asr_inference_with_chatrwkv.cpython-311.pyc differ
 
asr_gradio/utils/asr_inference_with_chatrwkv.py CHANGED
@@ -15,6 +15,9 @@ import numpy as np
  import click
  import time
  import copy
+ from concurrent.futures import ThreadPoolExecutor
+ import threading
+
  @dataclass
  class AsrModels:
      audio_llm: RWKV
@@ -24,6 +27,7 @@ class AsrModels:
      project2_linear: torch.nn.Linear
      llm: RWKV
      tokenizer: TRIE_TOKENIZER
+     thread_pool: ThreadPoolExecutor


  def forward_one_with_embeds(model :RWKV,embeds:torch.Tensor,state:List[torch.Tensor]):
@@ -107,6 +111,9 @@ def load_asr_models(audio_lm_path, llm_path,whisper_path,tokenizer_path,device,d
      project2_linear = torch.nn.Linear(project2['weight'].shape[1], project2['weight'].shape[0])
      project2_linear.load_state_dict(project2)
      tokenizer = TRIE_TOKENIZER(tokenizer_path)
+     # Create a resident thread pool with a fixed three worker threads
+     thread_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="ASR-Inference")
+
      return AsrModels(
          audio_llm=audio_llm,
          whisper_feature_extractor=whisper_feature_extractor,
@@ -115,6 +122,7 @@ def load_asr_models(audio_lm_path, llm_path,whisper_path,tokenizer_path,device,d
          project2_linear=project2_linear.to(device=device,dtype=dtype),
          llm=llm,
          tokenizer=tokenizer,
+         thread_pool=thread_pool,
      )

  def calculate_perplexity(models, generated_tokens, dtype, device):
@@ -159,6 +167,46 @@ def calculate_perplexity(models, generated_tokens, dtype, device):

      return perplexity

+ def single_inference_task(initial_logits, init_state, models, dtype, device, task_id):
+     """
+     A single inference task, meant to run concurrently
+
+     Args:
+         initial_logits: initial logits
+         init_state: initial model state
+         models: the AsrModels bundle
+         dtype: data type
+         device: device
+         task_id: task ID
+
+     Returns:
+         tuple: (results, perplexity)
+     """
+     start_time = time.time()
+     print(f"Task {task_id} started")
+
+     # Generate the token sequence
+     next_token = sample_logits(initial_logits, top_k=10, top_p=0.6, temperature=0.6)
+     results = []
+     results.append(next_token)
+     state = copy.deepcopy(init_state)
+
+     while len(results) < 1024:
+         logits, state = models.llm.forward([next_token], state)
+         next_token = sample_logits(logits, top_k=10, top_p=0.6, temperature=0.6)
+         results.append(next_token)
+         if next_token == 0:
+             break
+
+     # Compute the perplexity of the generated sequence
+     print(f"Task {task_id}: computing perplexity, sequence length: {len(results)}")
+     perplexity = calculate_perplexity(models, results, dtype, device)
+     print(f"Task {task_id}: generated-sequence perplexity: {perplexity:.4f}")
+
+     end_time = time.time()
+     print(f"Task {task_id} execution time: {end_time - start_time}")
+     return results, perplexity
+
  def sample_logits(logits, temperature=1.0, top_p=0.85, top_k=0):
      if temperature == 0:
          temperature = 1.0
@@ -231,7 +279,7 @@ def extract_audio_latents(models, audio_file_path,dtype):
      return projected_latents,audio_valid_length

  @torch.inference_mode()
- def inference_asr(models, audio_path, language,dtype,device,resample_count = 1):
+ def inference_asr(models, audio_path, language, dtype, device, resample_count=1):
      if language == 'chinese':
          print(f'language: {language}')
          instruction = "User: 请将以下语音转写为中文。\n"
@@ -244,7 +292,7 @@ def inference_asr(models, audio_path, language,dtype,device,resample_count = 1):
      print(f'load audio from {audio_path}')
      audio_path = audio_path
      time_start = time.time()
-     audio_latents,audio_valid_length = extract_audio_latents(models, audio_path,dtype)
+     audio_latents, audio_valid_length = extract_audio_latents(models, audio_path, dtype)
      time_end = time.time()
      print(f'whisper time: {time_end - time_start}')
      time_start = time.time()
@@ -261,33 +309,47 @@ def inference_asr(models, audio_path, language,dtype,device,resample_count = 1):
      with torch.no_grad():
          audio_latents = F.layer_norm(audio_latents, (models.llm.n_embd,), weight=models.llm.z['blocks.0.ln0.weight'], bias=models.llm.z['blocks.0.ln0.bias'])#do the first layer norm for embeddings input
      whole_input_embeds = torch.cat([instruction_input_embeds, audio_latents, hints_input_embeds], dim=0)
-     hidden_states,init_state = forward_seq_with_embeds(models.llm, whole_input_embeds, dtype, device, None, False)
+     hidden_states, init_state = forward_seq_with_embeds(models.llm, whole_input_embeds, dtype, device, None, False)
      time_end = time.time()
      print(f'prefill time: {time_end - time_start}')
      with torch.no_grad():
          initial_logits = hidden_states @ models.llm.z['head.weight']
+
+     # Run the resamples concurrently on the resident ThreadPoolExecutor
+     print("Starting concurrent inference")
      scored_results = []
-     for i in range(resample_count):
-         next_token = sample_logits(initial_logits,top_k=10,top_p=0.6,temperature=0.6)
-         results = []
-         results.append(next_token)
-         state = copy.deepcopy(init_state)
-         while len(results) < 1024:
-             logits,state = models.llm.forward([next_token], state)
-             next_token = sample_logits(logits,top_k=10,top_p=0.6,temperature=0.6)
-             results.append(next_token)
-             if next_token == 0:
-                 break
-
-         # Compute the perplexity of the generated sequence
-         print(f"Computing generated-sequence perplexity, sequence length: {len(results)}")
-         perplexity = calculate_perplexity(models, results, dtype, device)
-         print(f"Generated-sequence perplexity: {perplexity:.4f}")
-         scored_results.append((results, perplexity))
+
+     # Submit all tasks to the resident thread pool
+     future_to_task = {
+         models.thread_pool.submit(single_inference_task, initial_logits, init_state, models, dtype, device, i): i
+         for i in range(resample_count)
+     }
+
+     # Collect all results
+     for future in future_to_task:
+         try:
+             results, perplexity = future.result()
+             scored_results.append((results, perplexity))
+         except Exception as exc:
+             task_id = future_to_task[future]
+             print(f'Task {task_id} raised an exception: {exc}')
+
      print(f'scored_results: {scored_results}')
      results, perplexity = min(scored_results, key=lambda x: x[1])
      return results[:-1], perplexity

+ def cleanup_models(models):
+     """
+     Release the models' resources and shut down the thread pool
+
+     Args:
+         models: an AsrModels instance
+     """
+     if hasattr(models, 'thread_pool') and models.thread_pool:
+         print("Shutting down the thread pool...")
+         models.thread_pool.shutdown(wait=True)
+         print("Thread pool shut down")
+
  @click.command()
  @click.option('--audio-lm-path', default="/home/yueyulin/models/rwkv7_0.1b_audio_lm_latents_1.5b_44k",
                help='Path to the audio language model')
@@ -325,13 +387,17 @@ def main(audio_lm_path, llm_path, whisper_path, audio_path, tokenizer_path, lang
      print(f'project1: {models.project1_linear}')
      print(f'project2: {models.project2_linear}')
      start_time = time.time()
-     results, perplexity = inference_asr(models, audio_path, language, dtype, device, resample_count=3)
-     print(f'results: {results}')
-     print(f'decode results: {models.tokenizer.decode(results)}')
-     print(f'perplexity: {perplexity:.4f}')
-     end_time = time.time()
-     print(f'time: {end_time - start_time}')
-     return results, perplexity
+     try:
+         results, perplexity = inference_asr(models, audio_path, language, dtype, device, resample_count=3)
+         print(f'results: {results}')
+         print(f'decode results: {models.tokenizer.decode(results)}')
+         print(f'perplexity: {perplexity:.4f}')
+         end_time = time.time()
+         print(f'time: {end_time - start_time}')
+         return results, perplexity
+     finally:
+         # Make sure the thread pool is shut down when the program exits
+         cleanup_models(models)

  if __name__ == "__main__":
      main()
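
The substantive change above swaps the sequential resample loop for a resident ThreadPoolExecutor created once in load_asr_models and shut down by cleanup_models. A standalone sketch of the same submit-and-collect pattern (work is an illustrative stand-in for single_inference_task, not from the repo):

    from concurrent.futures import ThreadPoolExecutor

    def work(task_id):
        # stand-in for single_inference_task: return (results, score)
        return [task_id], float(task_id)

    pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="ASR-Inference")
    future_to_task = {pool.submit(work, i): i for i in range(3)}
    scored = []
    for future in future_to_task:       # iterates in submission order;
        try:                            # as_completed would yield futures
            scored.append(future.result())  # as they finish instead
        except Exception as exc:
            print(f"Task {future_to_task[future]} raised: {exc}")
    best_results, best_score = min(scored, key=lambda x: x[1])
    pool.shutdown(wait=True)

One caveat worth flagging: torch.inference_mode() is a thread-local guard, so the pool threads spawned from the decorated inference_asr do not inherit it; any no-grad behaviour inside single_inference_task comes from the functions it calls.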