datxy commited on
Commit
3fd6801
·
verified ·
1 Parent(s): 3f8297c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -65
app.py CHANGED
@@ -7,44 +7,41 @@ import torchaudio
7
  import gradio as gr
8
  import tempfile
9
 
10
- llasa_3b ='srinivasbilla/llasa-3b'
 
11
 
 
12
  tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
13
 
14
  model = AutoModelForCausalLM.from_pretrained(
15
  llasa_3b,
16
  trust_remote_code=True,
17
- device_map='cuda',
18
  )
19
 
20
  model_path = "srinivasbilla/xcodec2"
21
-
22
  Codec_model = XCodec2Model.from_pretrained(model_path)
23
- Codec_model.eval().cuda()
24
 
25
  whisper_turbo_pipe = pipeline(
26
  "automatic-speech-recognition",
27
  model="openai/whisper-large-v3-turbo",
28
- torch_dtype=torch.float16,
29
- device='cuda',
30
  )
31
 
32
  def ids_to_speech_tokens(speech_ids):
33
-
34
- speech_tokens_str = []
35
- for speech_id in speech_ids:
36
- speech_tokens_str.append(f"<|s_{speech_id}|>")
37
- return speech_tokens_str
38
 
39
  def extract_speech_ids(speech_tokens_str):
40
-
41
  speech_ids = []
42
  for token_str in speech_tokens_str:
43
  if token_str.startswith('<|s_') and token_str.endswith('|>'):
44
- num_str = token_str[4:-2]
45
-
46
- num = int(num_str)
47
- speech_ids.append(num)
 
48
  else:
49
  print(f"Unexpected token: {token_str}")
50
  return speech_ids
@@ -54,20 +51,19 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
54
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
55
  progress(0, 'Loading and trimming audio...')
56
  waveform, sample_rate = torchaudio.load(sample_audio_path)
57
- if len(waveform[0])/sample_rate > 15:
58
- gr.Warning("Trimming audio to first 15secs.")
59
- waveform = waveform[:, :sample_rate*15]
60
 
61
- # Check if the audio is stereo (i.e., has more than one channel)
 
 
 
62
  if waveform.size(0) > 1:
63
- # Convert stereo to mono by averaging the channels
64
  waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
65
  else:
66
- # If already mono, just use the original waveform
67
  waveform_mono = waveform
68
 
69
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
70
  prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
 
71
  progress(0.5, 'Transcribed! Generating speech...')
72
 
73
  if len(target_text) == 0:
@@ -75,86 +71,65 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
75
  elif len(target_text) > 300:
76
  gr.Warning("Text is too long. Please keep it under 300 characters.")
77
  target_text = target_text[:300]
78
-
79
  input_text = prompt_text + ' ' + target_text
80
 
81
- #TTS start!
82
  with torch.no_grad():
83
- # Encode the prompt wav
84
  vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
85
-
86
- vq_code_prompt = vq_code_prompt[0,0,:]
87
- # Convert int 12345 to token <|s_12345|>
88
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
89
 
90
  formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
91
-
92
- # Tokenize the text and the speech prefix
93
  chat = [
94
  {"role": "user", "content": "Convert the text to speech:" + formatted_text},
95
  {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
96
  ]
97
 
98
  input_ids = tokenizer.apply_chat_template(
99
- chat,
100
- tokenize=True,
101
- return_tensors='pt',
102
  continue_final_message=True
103
- )
104
- input_ids = input_ids.to('cuda')
105
  speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
106
 
107
- # Generate the speech autoregressively
108
  outputs = model.generate(
109
  input_ids,
110
- max_length=2048, # We trained our model with a max length of 2048
111
- eos_token_id= speech_end_id ,
112
  do_sample=True,
113
- top_p=1,
114
  temperature=0.8
115
  )
116
- # Extract the speech tokens
117
- generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
118
 
119
- speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
120
-
121
- # Convert token <|s_23456|> to int 23456
122
  speech_tokens = extract_speech_ids(speech_tokens)
123
 
124
- speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
125
-
126
- # Decode the speech tokens to speech waveform
127
- gen_wav = Codec_model.decode_code(speech_tokens)
128
-
129
- # if only need the generated part
130
- gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
131
 
132
  progress(1, 'Synthesized!')
133
-
134
- return (16000, gen_wav[0, 0, :].cpu().numpy())
135
 
136
  with gr.Blocks() as app_tts:
137
  gr.Markdown("# Zero Shot Voice Clone TTS")
138
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
139
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
140
-
141
  generate_btn = gr.Button("Synthesize", variant="primary")
142
-
143
  audio_output = gr.Audio(label="Synthesized Audio")
144
 
145
  generate_btn.click(
146
  infer,
147
- inputs=[
148
- ref_audio_input,
149
- gen_text_input,
150
- ],
151
  outputs=[audio_output],
152
  )
153
 
154
  with gr.Blocks() as app_credits:
155
  gr.Markdown("""
156
  # Credits
157
-
158
  * [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
159
  * [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
160
  """)
@@ -163,15 +138,11 @@ with gr.Blocks() as app:
163
  gr.Markdown(
164
  """
165
  # llasa 3b TTS
166
-
167
  This is a local web UI for llasa 3b SOTA(imo) Zero Shot Voice Cloning and TTS model.
168
-
169
  The checkpoints support English and Chinese.
170
-
171
  If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
172
  """
173
  )
174
  gr.TabbedInterface([app_tts], ["TTS"])
175
 
176
-
177
- app.launch(ssr_mode=False)
 
7
  import gradio as gr
8
  import tempfile
9
 
10
+ # ✅ 自动选择设备(GPU 优先)
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ llasa_3b = 'srinivasbilla/llasa-3b'
14
  tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
15
 
16
  model = AutoModelForCausalLM.from_pretrained(
17
  llasa_3b,
18
  trust_remote_code=True,
19
+ device_map=device,
20
  )
21
 
22
  model_path = "srinivasbilla/xcodec2"
 
23
  Codec_model = XCodec2Model.from_pretrained(model_path)
24
+ Codec_model.eval().to(device)
25
 
26
  whisper_turbo_pipe = pipeline(
27
  "automatic-speech-recognition",
28
  model="openai/whisper-large-v3-turbo",
29
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
30
+ device=0 if device == "cuda" else -1,
31
  )
32
 
33
  def ids_to_speech_tokens(speech_ids):
34
+ return [f"<|s_{sid}|>" for sid in speech_ids]
 
 
 
 
35
 
36
  def extract_speech_ids(speech_tokens_str):
 
37
  speech_ids = []
38
  for token_str in speech_tokens_str:
39
  if token_str.startswith('<|s_') and token_str.endswith('|>'):
40
+ try:
41
+ num = int(token_str[4:-2])
42
+ speech_ids.append(num)
43
+ except ValueError:
44
+ print(f"Invalid token format: {token_str}")
45
  else:
46
  print(f"Unexpected token: {token_str}")
47
  return speech_ids
 
51
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
52
  progress(0, 'Loading and trimming audio...')
53
  waveform, sample_rate = torchaudio.load(sample_audio_path)
 
 
 
54
 
55
+ if len(waveform[0]) / sample_rate > 15:
56
+ gr.Warning("Trimming audio to first 15 seconds.")
57
+ waveform = waveform[:, :sample_rate * 15]
58
+
59
  if waveform.size(0) > 1:
 
60
  waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
61
  else:
 
62
  waveform_mono = waveform
63
 
64
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
65
  prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
66
+
67
  progress(0.5, 'Transcribed! Generating speech...')
68
 
69
  if len(target_text) == 0:
 
71
  elif len(target_text) > 300:
72
  gr.Warning("Text is too long. Please keep it under 300 characters.")
73
  target_text = target_text[:300]
74
+
75
  input_text = prompt_text + ' ' + target_text
76
 
 
77
  with torch.no_grad():
 
78
  vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
79
+ vq_code_prompt = vq_code_prompt[0, 0, :]
 
 
80
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
81
 
82
  formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
 
 
83
  chat = [
84
  {"role": "user", "content": "Convert the text to speech:" + formatted_text},
85
  {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
86
  ]
87
 
88
  input_ids = tokenizer.apply_chat_template(
89
+ chat,
90
+ tokenize=True,
91
+ return_tensors='pt',
92
  continue_final_message=True
93
+ ).to(device)
94
+
95
  speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
96
 
 
97
  outputs = model.generate(
98
  input_ids,
99
+ max_length=2048,
100
+ eos_token_id=speech_end_id,
101
  do_sample=True,
102
+ top_p=1,
103
  temperature=0.8
104
  )
 
 
105
 
106
+ generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
107
+ speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
 
108
  speech_tokens = extract_speech_ids(speech_tokens)
109
 
110
+ speech_tensor = torch.tensor(speech_tokens).to(device).unsqueeze(0).unsqueeze(0)
111
+ gen_wav = Codec_model.decode_code(speech_tensor)
112
+ gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]
 
 
 
 
113
 
114
  progress(1, 'Synthesized!')
115
+ return (16000, gen_wav[0, 0, :].cpu().numpy())
 
116
 
117
  with gr.Blocks() as app_tts:
118
  gr.Markdown("# Zero Shot Voice Clone TTS")
119
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
120
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
 
121
  generate_btn = gr.Button("Synthesize", variant="primary")
 
122
  audio_output = gr.Audio(label="Synthesized Audio")
123
 
124
  generate_btn.click(
125
  infer,
126
+ inputs=[ref_audio_input, gen_text_input],
 
 
 
127
  outputs=[audio_output],
128
  )
129
 
130
  with gr.Blocks() as app_credits:
131
  gr.Markdown("""
132
  # Credits
 
133
  * [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
134
  * [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
135
  """)
 
138
  gr.Markdown(
139
  """
140
  # llasa 3b TTS
 
141
  This is a local web UI for llasa 3b SOTA(imo) Zero Shot Voice Cloning and TTS model.
 
142
  The checkpoints support English and Chinese.
 
143
  If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
144
  """
145
  )
146
  gr.TabbedInterface([app_tts], ["TTS"])
147
 
148
+ app.launch(ssr_mode=False)