Autorestart space: CUDA Error (#15)
Autorestart space (024283f748e35613e326c48c37d82ea8637f89d4)
Co-authored-by: Max Skobeev <[email protected]>
app.py
CHANGED
@@ -119,90 +119,98 @@ def build_html_error_message(error):
 @GPU_DECORATOR
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):
-    # Parse reference audio aka prompt
-    refs = req.references
-
-    prompt_tokens = [
-        encode_reference(
-            decoder_model=decoder_model,
-            reference_audio=ref.audio,
-            enable_reference_audio=True,
-        )
-        for ref in refs
-    ]
-    prompt_texts = [ref.text for ref in refs]
-
-    if req.seed is not None:
-        set_seed(req.seed)
-        logger.warning(f"set seed: {req.seed}")
-
-    # LLAMA Inference
-    request = dict(
-        device=decoder_model.device,
-        max_new_tokens=req.max_new_tokens,
-        text=(
-            req.text
-            if not req.normalize
-            else ChnNormedText(raw_text=req.text).normalize()
-        ),
-        top_p=req.top_p,
-        repetition_penalty=req.repetition_penalty,
-        temperature=req.temperature,
-        compile=args.compile,
-        iterative_prompt=req.chunk_length > 0,
-        chunk_length=req.chunk_length,
-        max_length=4096,
-        prompt_tokens=prompt_tokens,
-        prompt_text=prompt_texts,
-    )
-
-    response_queue = queue.Queue()
-    llama_queue.put(
-        GenerateRequest(
-            request=request,
-            response_queue=response_queue,
-        )
-    )
-
-    segments = []
-
-    while True:
-        result: WrappedGenerateResponse = response_queue.get()
-        if result.status == "error":
-            yield None, None, build_html_error_message(result.response)
-            break
-
-        result: GenerateResponse = result.response
-        if result.action == "next":
-            break
-
-        with autocast_exclude_mps(
-            device_type=decoder_model.device.type, dtype=args.precision
-        ):
-            fake_audios = decode_vq_tokens(
-                decoder_model=decoder_model,
-                codes=result.codes,
-            )
-
-        fake_audios = fake_audios.float().cpu().numpy()
-        segments.append(fake_audios)
-
-    if len(segments) == 0:
-        return (
-            None,
-            None,
-            build_html_error_message(
-                i18n("No audio generated, please check the input text.")
-            ),
-        )
-
-    # No matter streaming or not, we need to return the final audio
-    audio = np.concatenate(segments, axis=0)
-    yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        gc.collect()
+    try:
+        # Parse reference audio aka prompt
+        refs = req.references
+
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=ref.audio,
+                enable_reference_audio=True,
+            )
+            for ref in refs
+        ]
+        prompt_texts = [ref.text for ref in refs]
+
+        if req.seed is not None:
+            set_seed(req.seed)
+            logger.warning(f"set seed: {req.seed}")
+
+        # LLAMA Inference
+        request = dict(
+            device=decoder_model.device,
+            max_new_tokens=req.max_new_tokens,
+            text=(
+                req.text
+                if not req.normalize
+                else ChnNormedText(raw_text=req.text).normalize()
+            ),
+            top_p=req.top_p,
+            repetition_penalty=req.repetition_penalty,
+            temperature=req.temperature,
+            compile=args.compile,
+            iterative_prompt=req.chunk_length > 0,
+            chunk_length=req.chunk_length,
+            max_length=4096,
+            prompt_tokens=prompt_tokens,
+            prompt_text=prompt_texts,
+        )
+
+        response_queue = queue.Queue()
+        llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        segments = []
+
+        while True:
+            result: WrappedGenerateResponse = response_queue.get()
+            if result.status == "error":
+                yield None, None, build_html_error_message(result.response)
+                break
+
+            result: GenerateResponse = result.response
+            if result.action == "next":
+                break
+
+            with autocast_exclude_mps(
+                device_type=decoder_model.device.type, dtype=args.precision
+            ):
+                fake_audios = decode_vq_tokens(
+                    decoder_model=decoder_model,
+                    codes=result.codes,
+                )
+
+            fake_audios = fake_audios.float().cpu().numpy()
+            segments.append(fake_audios)
+
+        if len(segments) == 0:
+            return (
+                None,
+                None,
+                build_html_error_message(
+                    i18n("No audio generated, please check the input text.")
+                ),
+            )
+
+        # No matter streaming or not, we need to return the final audio
+        audio = np.concatenate(segments, axis=0)
+        yield None, (decoder_model.spec_transform.sample_rate, audio), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    except Exception as e:
+        er = "CUDA error: device-side assert triggered"
+        if er in e:
+            app.close()
+        else:
+            raise Exception(e)

 n_audios = 4
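Two Python-level details are worth flagging before reusing this handler elsewhere: `er in e` performs a membership test against the exception object itself, which raises a TypeError at runtime, and `raise Exception(e)` replaces the original exception type and traceback. Below is a minimal sketch of the intended pattern; the helper name run_with_autorestart and the module-level `app = gr.Blocks()` are illustrative stand-ins, not part of this commit, and the only assumption carried over from the diff is that closing the Gradio app is what makes the Space restart.

import gradio as gr

# Illustrative stand-in for the gr.Blocks instance that app.py builds and launches.
app = gr.Blocks()

CUDA_ASSERT = "CUDA error: device-side assert triggered"

def run_with_autorestart(fn, *args, **kwargs):
    """Run fn; on an unrecoverable device-side assert, close the app so the Space restarts."""
    try:
        return fn(*args, **kwargs)
    except Exception as e:
        if CUDA_ASSERT in str(e):  # compare against str(e); `in e` would raise TypeError
            app.close()            # shut down the Gradio server and let the Space restart
        else:
            raise                  # bare raise keeps the original exception type and traceback

Because inference() is a generator, the check has to live inside the function body, as the diff does, rather than around a plain call; the wrapper above only illustrates the corrected string comparison and re-raise.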