Upload folder using huggingface_hub
- asr_gradio/rwkv7-g1a-0.4b-20250905-ctx4096.pth +3 -0
- asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/config.json +50 -0
- asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/model.safetensors +3 -0
- asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/model_converted.pth +3 -0
- asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector1.pt +3 -0
- asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector2.pt +3 -0
- asr_gradio/utils/__pycache__/asr_inference_with_chatrwkv.cpython-311.pyc +0 -0
- asr_gradio/utils/asr_inference_with_chatrwkv.py +93 -27
asr_gradio/rwkv7-g1a-0.4b-20250905-ctx4096.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d852e99ef6c95726109660c64e7c51a8df30c53b0832a68645bfcd15253b3109
+size 901776757
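The payload above is a Git LFS pointer, not the checkpoint itself: the repository stores only the object's SHA-256 and byte size, while the ~900 MB weights live in LFS storage. A minimal sketch of pulling the resolved file with `huggingface_hub` (the repo id below is a placeholder; substitute the repository this commit belongs to):

```python
from huggingface_hub import hf_hub_download

# Hypothetical repo id -- replace with the actual repository.
local_path = hf_hub_download(
    repo_id="user/asr_gradio_demo",
    filename="asr_gradio/rwkv7-g1a-0.4b-20250905-ctx4096.pth",
)
print(local_path)  # local cache path of the resolved 901,776,757-byte checkpoint
```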
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/config.json
ADDED
@@ -0,0 +1,50 @@
+{
+    "a_low_rank_dim": 64,
+    "architectures": [
+        "RWKV7ModelForLatentInputs"
+    ],
+    "attn": null,
+    "attn_mode": "chunk",
+    "bos_token_id": 0,
+    "decay_low_rank_dim": 64,
+    "eos_token_id": 0,
+    "fuse_cross_entropy": true,
+    "fuse_linear_cross_entropy": false,
+    "fuse_norm": false,
+    "gate_low_rank_dim": 128,
+    "head_dim": 64,
+    "hidden_act": "sqrelu",
+    "hidden_ratio": 4.0,
+    "hidden_size": 768,
+    "initializer_range": 0.006,
+    "intermediate_size": 3072,
+    "max_position_embeddings": 2048,
+    "model_type": "rwkv7",
+    "norm_bias": true,
+    "norm_eps": 1e-05,
+    "norm_first": true,
+    "num_heads": 32,
+    "num_hidden_layers": 12,
+    "tie_word_embeddings": false,
+    "torch_dtype": "float32",
+    "train_time_state": false,
+    "transformers_version": "4.52.4",
+    "use_cache": true,
+    "use_l2warp": true,
+    "v_low_rank_dim": 32,
+    "value_dim": [
+        768,
+        768,
+        768,
+        768,
+        768,
+        768,
+        768,
+        768,
+        768,
+        768,
+        768,
+        768
+    ],
+    "vocab_size": 10
+}
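A quick sanity check on the shapes declared above, using only the standard library (this assumes the file has been downloaded locally at the path shown):

```python
import json

with open("asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/config.json") as f:
    cfg = json.load(f)

# One 768-dim value projection per layer: 12 entries for 12 hidden layers.
assert len(cfg["value_dim"]) == cfg["num_hidden_layers"] == 12
print(cfg["model_type"], cfg["hidden_size"], cfg["vocab_size"])
# A vocab_size of 10 is consistent with a model fed projected audio latents
# (per the RWKV7ModelForLatentInputs architecture name) rather than text tokens.
```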
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ecf58ad41179876c07d341991b5bbfb2796d2669ebcb76657d7befe7b139934
+size 361523904
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/model_converted.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3dc29a8ee50e423522d8f0710f4942f7934a574889fd5b84b104b1bc2eb3139d
+size 361668745
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector1.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2d102133e65bb3da053d33c70d05083a43221b9e41ac7e37d759ddc98f909aa
+size 3937149
asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector2.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65febcb6484e4a5464d0df62fe088d4533e8dac2724e1c2fcabe59d60ac0b1b2
+size 3151741
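Each pointer also records the SHA-256 of the real object, so a downloaded file can be verified against it. A small sketch using the `projector2.pt` hash from the pointer above (the local path is a placeholder):

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file in 1 MiB chunks so large checkpoints never sit in memory whole."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "65febcb6484e4a5464d0df62fe088d4533e8dac2724e1c2fcabe59d60ac0b1b2"
assert sha256_of("asr_gradio/rwkv7_0.1b_audio_lm_latents_280k/projector2.pt") == expected
```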
asr_gradio/utils/__pycache__/asr_inference_with_chatrwkv.cpython-311.pyc
CHANGED
Binary files a/asr_gradio/utils/__pycache__/asr_inference_with_chatrwkv.cpython-311.pyc and b/asr_gradio/utils/__pycache__/asr_inference_with_chatrwkv.cpython-311.pyc differ
asr_gradio/utils/asr_inference_with_chatrwkv.py
CHANGED
@@ -15,6 +15,9 @@ import numpy as np
 import click
 import time
 import copy
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
 @dataclass
 class AsrModels:
     audio_llm: RWKV
@@ -24,6 +27,7 @@ class AsrModels:
     project2_linear: torch.nn.Linear
     llm: RWKV
     tokenizer: TRIE_TOKENIZER
+    thread_pool: ThreadPoolExecutor


 def forward_one_with_embeds(model :RWKV,embeds:torch.Tensor,state:List[torch.Tensor]):
@@ -107,6 +111,9 @@ def load_asr_models(audio_lm_path, llm_path,whisper_path,tokenizer_path,device,d
     project2_linear = torch.nn.Linear(project2['weight'].shape[1], project2['weight'].shape[0])
     project2_linear.load_state_dict(project2)
     tokenizer = TRIE_TOKENIZER(tokenizer_path)
+    # Create a persistent thread pool for inference tasks (max_workers left at the default)
+    thread_pool = ThreadPoolExecutor(thread_name_prefix="ASR-Inference")
+
     return AsrModels(
         audio_llm=audio_llm,
         whisper_feature_extractor=whisper_feature_extractor,
@@ -115,6 +122,7 @@ def load_asr_models(audio_lm_path, llm_path,whisper_path,tokenizer_path,device,d
         project2_linear=project2_linear.to(device=device,dtype=dtype),
         llm=llm,
         tokenizer=tokenizer,
+        thread_pool=thread_pool,
     )

 def calculate_perplexity(models, generated_tokens, dtype, device):
@@ -159,6 +167,46 @@ def calculate_perplexity(models, generated_tokens, dtype, device):

     return perplexity

+def single_inference_task(initial_logits, init_state, models, dtype, device, task_id):
+    """
+    One inference task, meant to be executed concurrently.
+
+    Args:
+        initial_logits: logits from the prefill pass
+        init_state: RWKV state from the prefill pass
+        models: the AsrModels bundle
+        dtype: data type
+        device: device
+        task_id: task id
+
+    Returns:
+        tuple: (results, perplexity)
+    """
+    start_time = time.time()
+    print(f"Task {task_id} started")
+
+    # Sample a candidate token sequence
+    next_token = sample_logits(initial_logits, top_k=10, top_p=0.6, temperature=0.6)
+    results = []
+    results.append(next_token)
+    state = copy.deepcopy(init_state)
+
+    while len(results) < 1024:
+        logits, state = models.llm.forward([next_token], state)
+        next_token = sample_logits(logits, top_k=10, top_p=0.6, temperature=0.6)
+        results.append(next_token)
+        if next_token == 0:
+            break
+
+    # Score the generated sequence by perplexity
+    print(f"Task {task_id} computing perplexity, sequence length: {len(results)}")
+    perplexity = calculate_perplexity(models, results, dtype, device)
+    print(f"Task {task_id} perplexity: {perplexity:.4f}")
+
+    end_time = time.time()
+    print(f"Task {task_id} elapsed: {end_time - start_time}")
+    return results, perplexity
+
 def sample_logits(logits, temperature=1.0, top_p=0.85, top_k=0):
     if temperature == 0:
         temperature = 1.0
@@ -231,7 +279,7 @@ def extract_audio_latents(models, audio_file_path,dtype):
     return projected_latents,audio_valid_length

 @torch.inference_mode()
-def inference_asr(models, audio_path, language,dtype,device,resample_count = 1):
+def inference_asr(models, audio_path, language, dtype, device, resample_count=1):
     if language == 'chinese':
         print(f'language: {language}')
         instruction = "User: 请将以下语音转写为中文。\n"
@@ -244,7 +292,7 @@ def inference_asr(models, audio_path, language,dtype,device,resample_count = 1):
     print(f'load audio from {audio_path}')
     audio_path = audio_path
     time_start = time.time()
-    audio_latents,audio_valid_length = extract_audio_latents(models, audio_path,dtype)
+    audio_latents, audio_valid_length = extract_audio_latents(models, audio_path, dtype)
     time_end = time.time()
     print(f'whisper time: {time_end - time_start}')
     time_start = time.time()
@@ -261,33 +309,47 @@ def inference_asr(models, audio_path, language,dtype,device,resample_count = 1):
     with torch.no_grad():
         audio_latents = F.layer_norm(audio_latents, (models.llm.n_embd,), weight=models.llm.z['blocks.0.ln0.weight'], bias=models.llm.z['blocks.0.ln0.bias'])  # do the first layer norm for embeddings input
         whole_input_embeds = torch.cat([instruction_input_embeds, audio_latents, hints_input_embeds], dim=0)
-    hidden_states,init_state = forward_seq_with_embeds(models.llm, whole_input_embeds, dtype, device, None, False)
+    hidden_states, init_state = forward_seq_with_embeds(models.llm, whole_input_embeds, dtype, device, None, False)
    time_end = time.time()
     print(f'prefill time: {time_end - time_start}')
     with torch.no_grad():
         initial_logits = hidden_states @ models.llm.z['head.weight']
+
+    # Run the resamples concurrently on the persistent ThreadPoolExecutor held by models
+    print(f"Starting concurrent inference")
     scored_results = []
-    [old lines 270-285 removed: the previous sequential resampling loop; contents not shown in this view]
-        scored_results.append((results, perplexity))
+
+    # Submit every resample task to the persistent thread pool
+    future_to_task = {
+        models.thread_pool.submit(single_inference_task, initial_logits, init_state, models, dtype, device, i): i
+        for i in range(resample_count)
+    }
+
+    # Collect all results
+    for future in future_to_task:
+        try:
+            results, perplexity = future.result()
+            scored_results.append((results, perplexity))
+        except Exception as exc:
+            task_id = future_to_task[future]
+            print(f'Task {task_id} raised an exception: {exc}')
+
     print(f'scored_results: {scored_results}')
     results, perplexity = min(scored_results, key=lambda x: x[1])
     return results[:-1], perplexity

+def cleanup_models(models):
+    """
+    Release model resources by shutting down the thread pool.
+
+    Args:
+        models: an AsrModels instance
+    """
+    if hasattr(models, 'thread_pool') and models.thread_pool:
+        print("Shutting down thread pool...")
+        models.thread_pool.shutdown(wait=True)
+        print("Thread pool shut down")
+
 @click.command()
 @click.option('--audio-lm-path', default="/home/yueyulin/models/rwkv7_0.1b_audio_lm_latents_1.5b_44k",
               help='Path to the audio language model')
@@ -325,13 +387,17 @@ def main(audio_lm_path, llm_path, whisper_path, audio_path, tokenizer_path, lang
     print(f'project1: {models.project1_linear}')
     print(f'project2: {models.project2_linear}')
     start_time = time.time()
-    [old lines 328-334 removed; contents not shown in this view]
+    try:
+        results, perplexity = inference_asr(models, audio_path, language, dtype, device, resample_count=3)
+        print(f'results: {results}')
+        print(f'decode results: {models.tokenizer.decode(results)}')
+        print(f'perplexity: {perplexity:.4f}')
+        end_time = time.time()
+        print(f'time: {end_time - start_time}')
+        return results, perplexity
+    finally:
+        # Make sure the thread pool is cleaned up on exit
+        cleanup_models(models)

 if __name__ == "__main__":
     main()
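The net effect of this change is a resample-and-rank decoding scheme: `resample_count` candidate transcriptions are sampled concurrently from the same prefilled state, each candidate is scored with `calculate_perplexity`, and the lowest-perplexity one is returned. A minimal self-contained sketch of the same pattern (the `generate` and `score` functions are toy placeholders, and `as_completed` is used where the diff iterates the futures dict in submission order):

```python
import random
from concurrent.futures import ThreadPoolExecutor, as_completed

def generate(task_id: int) -> list:
    # Toy stand-in for sampling one candidate token sequence (0 = end token).
    return [random.randint(1, 9) for _ in range(10)] + [0]

def score(tokens: list) -> float:
    # Toy stand-in for perplexity; lower is better.
    return sum(tokens) / len(tokens)

with ThreadPoolExecutor(thread_name_prefix="ASR-Inference") as pool:
    futures = {pool.submit(generate, i): i for i in range(3)}
    scored = []
    for future in as_completed(futures):  # yields each future as it finishes
        tokens = future.result()
        scored.append((tokens, score(tokens)))

best_tokens, best_ppl = min(scored, key=lambda x: x[1])
print(best_tokens, best_ppl)
```

One design point worth noting: each task deep-copies the prefill state and runs a full RWKV forward, so real wall-clock speedup from threads depends on the backend releasing the GIL during those forwards, which GPU-bound PyTorch ops generally do.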
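For completeness, a hedged sketch of driving the updated module programmatically instead of through the click CLI. All paths are placeholders, and the exact positional order of `load_asr_models` should be checked against the file, since its signature is truncated in this view:

```python
import torch
from asr_gradio.utils.asr_inference_with_chatrwkv import (
    load_asr_models, inference_asr, cleanup_models,
)

# Placeholder paths -- substitute real checkpoint locations.
models = load_asr_models(
    "asr_gradio/rwkv7_0.1b_audio_lm_latents_280k",      # audio_lm_path
    "asr_gradio/rwkv7-g1a-0.4b-20250905-ctx4096.pth",   # llm_path
    "/path/to/whisper",                                 # whisper_path
    "/path/to/rwkv_vocab.txt",                          # tokenizer_path
    "cuda",                                             # device
    torch.bfloat16,                                     # dtype (assumed)
)
try:
    results, ppl = inference_asr(models, "sample.wav", "chinese",
                                 torch.bfloat16, "cuda", resample_count=3)
    print(models.tokenizer.decode(results), f"perplexity={ppl:.4f}")
finally:
    cleanup_models(models)  # shuts down the persistent thread pool
```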