Commit
·
a0865e1
1
Parent(s):
96d5a10
(wip) add gpu tag
Browse files
cosyvoice/cli/cosyvoice.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
import os
|
|
|
|
| 15 |
import time
|
| 16 |
from tqdm import tqdm
|
| 17 |
from hyperpyyaml import load_hyperpyyaml
|
|
@@ -58,7 +59,7 @@ class CosyVoice:
|
|
| 58 |
def list_avaliable_spks(self):
|
| 59 |
spks = list(self.frontend.spk2info.keys())
|
| 60 |
return spks
|
| 61 |
-
|
| 62 |
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
|
| 63 |
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
| 64 |
model_input = self.frontend.frontend_sft(i, spk_id)
|
|
@@ -70,6 +71,7 @@ class CosyVoice:
|
|
| 70 |
yield model_output
|
| 71 |
start_time = time.time()
|
| 72 |
|
|
|
|
| 73 |
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
|
| 74 |
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
| 75 |
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
|
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
import os
|
| 15 |
+
import spaces
|
| 16 |
import time
|
| 17 |
from tqdm import tqdm
|
| 18 |
from hyperpyyaml import load_hyperpyyaml
|
|
|
|
| 59 |
def list_avaliable_spks(self):
|
| 60 |
spks = list(self.frontend.spk2info.keys())
|
| 61 |
return spks
|
| 62 |
+
@spaces.GPU
|
| 63 |
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
|
| 64 |
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
| 65 |
model_input = self.frontend.frontend_sft(i, spk_id)
|
|
|
|
| 71 |
yield model_output
|
| 72 |
start_time = time.time()
|
| 73 |
|
| 74 |
+
@spaces.GPU
|
| 75 |
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
|
| 76 |
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
| 77 |
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|