mrfakename committed
Commit 5c81b55 · 1 Parent(s): 419883c

ZeroGPU compat
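
ZeroGPU Spaces attach a GPU only while a function decorated with @spaces.GPU is running, so the per-device worker processes and queues are no longer needed: the model is constructed once at startup and each Gradio callback borrows a GPU just for the duration of the call. A minimal, self-contained sketch of that pattern, assuming a ZeroGPU host with the spaces package installed (the toy Linear model, the infer name, and the optional duration argument are illustrative, not part of app.py):

import spaces  # Hugging Face "spaces" package, present on ZeroGPU hardware
import torch

# Built once at startup, before any GPU is attached to the process.
model = torch.nn.Linear(4, 4)

@spaces.GPU(duration=60)  # a GPU is attached only while this call runs
def infer(x):
    # Inside the decorated call, CUDA is available on ZeroGPU hardware.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        return model.to(device)(torch.tensor(x, device=device)).cpu().tolist()

Note that the new app.py imports spaces before torch; ZeroGPU is generally imported first so it can hook CUDA initialization.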

Files changed (1)
app.py +76 -107
app.py CHANGED
@@ -1,10 +1,9 @@
-import multiprocessing as mp
+import spaces
 import torch
 import os
-from functools import partial
 import gradio as gr
 import traceback
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import snapshot_download
 from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav


@@ -27,121 +26,91 @@ def download_weights():
     return weights_dir


-def model_worker(input_queue, output_queue, device_id):
-    device = None
-    if device_id is not None:
-        device = torch.device(f'cuda:{device_id}')
-    infer_pipe = MegaTTS3DiTInfer(device=device)
+# Download weights and initialize model
+download_weights()
+print("Initializing MegaTTS3 model...")
+infer_pipe = MegaTTS3DiTInfer()
+print("Model loaded successfully!")

-    while True:
-        task = input_queue.get()
-        inp_audio_path, inp_text, infer_timestep, p_w, t_w = task
-        try:
-            convert_to_wav(inp_audio_path)
-            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
-            cut_wav(wav_path, max_len=28)
-            with open(wav_path, 'rb') as file:
-                file_content = file.read()
-            resource_context = infer_pipe.preprocess(file_content)
-            wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
-            output_queue.put(wav_bytes)
-        except Exception as e:
-            traceback.print_exc()
-            print(task, str(e))
-            output_queue.put(None)
-
-
-def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
+@spaces.GPU
+def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
     if not inp_audio or not inp_text:
         gr.Warning("Please provide both reference audio and text to generate.")
         return None

-    print("Generating speech with:", inp_audio, inp_text, infer_timestep, p_w, t_w)
-    input_queue.put((inp_audio, inp_text, infer_timestep, p_w, t_w))
-    res = output_queue.get()
-    if res is not None:
-        return res
-    else:
-        gr.Warning("Speech generation failed. Please try again.")
+    try:
+        print(f"Generating speech with: {inp_text[:50]}...")
+
+        # Convert and prepare audio
+        convert_to_wav(inp_audio)
+        wav_path = os.path.splitext(inp_audio)[0] + '.wav'
+        cut_wav(wav_path, max_len=28)
+
+        # Read audio file
+        with open(wav_path, 'rb') as file:
+            file_content = file.read()
+
+        # Generate speech
+        resource_context = infer_pipe.preprocess(file_content)
+        wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
+
+        return wav_bytes
+    except Exception as e:
+        traceback.print_exc()
+        gr.Warning(f"Speech generation failed: {str(e)}")
         return None


-if __name__ == '__main__':
-    # Download weights before starting
-    download_weights()
-
-    mp.set_start_method('spawn', force=True)
-    mp_manager = mp.Manager()
-
-    devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
-    if devices != '':
-        devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
-    else:
-        devices = None
+with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
+    gr.Markdown("# MegaTTS3 Voice Cloning")
+    gr.Markdown("Upload a reference audio clip and enter text to generate speech with the cloned voice.")

-    num_workers = 1
-    input_queue = mp_manager.Queue()
-    output_queue = mp_manager.Queue()
-    processes = []
-
-    print("Starting workers...")
-    for i in range(num_workers):
-        p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
-        p.start()
-        processes.append(p)
-
-    with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
-        gr.Markdown("# MegaTTS3 Voice Cloning")
-        gr.Markdown("Upload a reference audio clip and enter text to generate speech with the cloned voice.")
-
-        with gr.Row():
-            with gr.Column():
-                reference_audio = gr.Audio(
-                    label="Reference Audio",
-                    type="filepath",
-                    sources=["upload", "microphone"]
+    with gr.Row():
+        with gr.Column():
+            reference_audio = gr.Audio(
+                label="Reference Audio",
+                type="filepath",
+                sources=["upload", "microphone"]
+            )
+            text_input = gr.Textbox(
+                label="Text to Generate",
+                placeholder="Enter the text you want to synthesize...",
+                lines=3
+            )
+
+            with gr.Accordion("Advanced Options", open=False):
+                infer_timestep = gr.Number(
+                    label="Inference Timesteps",
+                    value=32,
+                    minimum=1,
+                    maximum=100,
+                    step=1
                 )
-                text_input = gr.Textbox(
-                    label="Text to Generate",
-                    placeholder="Enter the text you want to synthesize...",
-                    lines=3
+                p_w = gr.Number(
+                    label="Intelligibility Weight",
+                    value=1.4,
+                    minimum=0.1,
+                    maximum=5.0,
+                    step=0.1
+                )
+                t_w = gr.Number(
+                    label="Similarity Weight",
+                    value=3.0,
+                    minimum=0.1,
+                    maximum=10.0,
+                    step=0.1
                 )
-
-                with gr.Accordion("Advanced Options", open=False):
-                    infer_timestep = gr.Number(
-                        label="Inference Timesteps",
-                        value=32,
-                        minimum=1,
-                        maximum=100,
-                        step=1
-                    )
-                    p_w = gr.Number(
-                        label="Intelligibility Weight",
-                        value=1.4,
-                        minimum=0.1,
-                        maximum=5.0,
-                        step=0.1
-                    )
-                    t_w = gr.Number(
-                        label="Similarity Weight",
-                        value=3.0,
-                        minimum=0.1,
-                        maximum=10.0,
-                        step=0.1
-                    )
-
-                generate_btn = gr.Button("Generate Speech", variant="primary")

-            with gr.Column():
-                output_audio = gr.Audio(label="Generated Audio")
+            generate_btn = gr.Button("Generate Speech", variant="primary")

-        generate_btn.click(
-            fn=partial(generate_speech, processes=processes, input_queue=input_queue, output_queue=output_queue),
-            inputs=[reference_audio, text_input, infer_timestep, p_w, t_w],
-            outputs=[output_audio]
-        )
-
-        demo.launch(server_name='0.0.0.0', server_port=7860, debug=True)
+            with gr.Column():
+                output_audio = gr.Audio(label="Generated Audio")

-    for p in processes:
-        p.join()
+    generate_btn.click(
+        fn=generate_speech,
+        inputs=[reference_audio, text_input, infer_timestep, p_w, t_w],
+        outputs=[output_audio]
+    )
+
+if __name__ == '__main__':
+    demo.launch(server_name='0.0.0.0', server_port=7860, debug=True)
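
Since app.py now imports spaces unconditionally, it assumes it is running on a Spaces/ZeroGPU host. If the same file should also run locally where that package may be missing, an optional shim, not part of this commit and shown only as a sketch, can degrade @spaces.GPU to a no-op:

try:
    import spaces
except ImportError:
    # Local fallback (assumption: the huggingface "spaces" package is absent):
    # expose a spaces.GPU that works bare or with arguments and does nothing.
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            if callable(func):
                return func
            return lambda f: f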