Rajhuggingface4253 commited on
Commit
e9c235a
·
verified ·
1 Parent(s): 82fadb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -148
app.py CHANGED
@@ -131,102 +131,6 @@ class NeuTTSONNXWrapper:
131
  outputs = self.session.run(self.output_names, inputs)
132
  return outputs[0] # Assuming first output is logits
133
 
134
- # --- ONNX Conversion Functions ---
135
-
136
- def convert_model_to_onnx():
137
- """Complete ONNX conversion with proper PyTorch 2.9+ parameters"""
138
- try:
139
- from transformers import AutoModelForCausalLM, AutoTokenizer
140
- import torch.onnx
141
-
142
- model_repo = "neuphonic/neutts-air"
143
- onnx_path = os.path.join(ONNX_MODEL_DIR, "neutts_backbone.onnx")
144
-
145
- logger.info("Starting optimized ONNX conversion...")
146
-
147
- # Load model with correct parameters
148
- tokenizer = AutoTokenizer.from_pretrained(model_repo)
149
- model = AutoModelForCausalLM.from_pretrained(
150
- model_repo,
151
- dtype=torch.float32, # ✅ FIXED: Use dtype instead of torch_dtype
152
- trust_remote_code=True
153
- ).cpu()
154
- model.eval()
155
-
156
- # Create proper dummy input
157
- dummy_input = torch.randint(0, tokenizer.vocab_size, (1, 512), dtype=torch.long)
158
-
159
- # ✅ COMPLETE FIX: Use correct ONNX export parameters for PyTorch 2.9+
160
- torch.onnx.export(
161
- model,
162
- dummy_input,
163
- onnx_path,
164
- input_names=['input_ids'],
165
- output_names=['logits'],
166
- # ✅ FIXED: Use dynamic_shapes instead of dynamic_axes
167
- dynamic_shapes={
168
- 'input_ids': {0: "batch_size", 1: "sequence_length"},
169
- 'logits': {0: "batch_size", 1: "sequence_length"}
170
- },
171
- # ✅ FIXED: Use opset_version 18 as recommended
172
- opset_version=18,
173
- do_constant_folding=True,
174
- export_params=True,
175
- verbose=False,
176
- # ✅ FIXED: Disable dynamo to avoid constraints violation
177
- export_type=torch.onnx.ExportTypes.ONNX,
178
- training=torch.onnx.TrainingMode.EVAL,
179
- )
180
-
181
- logger.info(f"✅ ONNX conversion successful: {onnx_path}")
182
- return True
183
-
184
- except Exception as e:
185
- logger.error(f"❌ ONNX conversion failed: {e}")
186
- # Fallback to legacy method if modern method fails
187
- return _fallback_onnx_conversion()
188
-
189
- def _fallback_onnx_conversion():
190
- """Legacy ONNX conversion as fallback"""
191
- try:
192
- from transformers import AutoModelForCausalLM, AutoTokenizer
193
- import torch.onnx
194
-
195
- model_repo = "neuphonic/neutts-air"
196
- onnx_path = os.path.join(ONNX_MODEL_DIR, "neutts_backbone.onnx")
197
-
198
- logger.info("Trying legacy ONNX conversion...")
199
-
200
- tokenizer = AutoTokenizer.from_pretrained(model_repo)
201
- model = AutoModelForCausalLM.from_pretrained(
202
- model_repo,
203
- torch_dtype=torch.float32
204
- ).cpu()
205
- model.eval()
206
-
207
- # Static input for legacy export
208
- dummy_input = torch.randint(0, 1000, (1, 256), dtype=torch.long)
209
-
210
- # Legacy export without dynamic shapes
211
- torch.onnx.export(
212
- model,
213
- dummy_input,
214
- onnx_path,
215
- input_names=['input_ids'],
216
- output_names=['logits'],
217
- opset_version=14,
218
- do_constant_folding=True,
219
- export_params=True,
220
- verbose=False,
221
- )
222
-
223
- logger.info(f"✅ Legacy ONNX conversion successful")
224
- return True
225
-
226
- except Exception as e:
227
- logger.error(f"❌ Legacy ONNX conversion also failed: {e}")
228
- return False
229
-
230
  class NeuTTSWrapper:
231
  def __init__(self, device: str = "cpu", use_onnx: bool = USE_ONNX):
232
  self.tts_model = None
@@ -443,44 +347,26 @@ class NeuTTSWrapper:
443
  raise ValueError("No valid speech tokens found.")
444
 
445
  def generate_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str) -> np.ndarray:
446
- """Optimized synthesis with ONNX backbone when available"""
 
447
  audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
 
 
448
  ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
449
 
450
- # Use ONNX backbone if available, otherwise PyTorch
451
- if self.use_onnx and self.onnx_wrapper is not None:
452
- return self._infer_onnx(text, ref_s, reference_text)
453
- else:
454
- with torch.no_grad():
455
- audio = self.tts_model.infer(text, ref_s, reference_text)
456
- return audio
457
 
458
- def _infer_onnx(self, text: str, ref_s: torch.Tensor, reference_text: str) -> np.ndarray:
459
- """Use ONNX backbone for maximum speed"""
460
- try:
461
- # Convert text to tokens using original method
462
- prompt_ids = self.tts_model._apply_chat_template(
463
- ref_s.tolist() if isinstance(ref_s, torch.Tensor) else ref_s,
464
- reference_text,
465
- text
466
- )
467
-
468
- # Run through ONNX backbone
469
- input_ids = np.array([prompt_ids], dtype=np.int64)
470
- logits = self.onnx_wrapper.generate_onnx(input_ids)
471
-
472
- # Convert logits to token IDs (simplified - you'd need proper tokenizer logic)
473
- # For now, fall back to PyTorch for token decoding
474
- logger.info("Using ONNX backbone + PyTorch token decoding")
475
- with torch.no_grad():
476
- audio = self.tts_model.infer(text, ref_s, reference_text)
477
- return audio
478
-
479
- except Exception as e:
480
- logger.warning(f"ONNX inference failed, falling back to PyTorch: {e}")
481
- with torch.no_grad():
482
- audio = self.tts_model.infer(text, ref_s, reference_text)
483
- return audio
484
 
485
  # --- Asynchronous Offloading ---
486
 
@@ -500,12 +386,10 @@ async def lifespan(app: FastAPI):
500
  try:
501
  # Convert to ONNX on first run if enabled but model doesn't exist
502
  if USE_ONNX and not os.path.exists(os.path.join(ONNX_MODEL_DIR, "neutts_backbone.onnx")):
503
- logger.info("First run: Attempting ONNX conversion for maximum performance...")
504
  success = await run_blocking_task_async(convert_model_to_onnx)
505
- if success:
506
- logger.info(" ONNX conversion successful - full optimization enabled")
507
- else:
508
- logger.info("ℹ️ ONNX conversion failed, using hybrid optimization")
509
 
510
  app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE, use_onnx=USE_ONNX)
511
 
@@ -549,12 +433,10 @@ async def health_check():
549
 
550
  onnx_status = "enabled" if USE_ONNX else "disabled"
551
  onnx_codec_status = "active"
552
- onnx_backbone_status = "inactive"
553
 
554
  if hasattr(app.state, 'tts_wrapper'):
555
  onnx_status = "active" if app.state.tts_wrapper.use_onnx else "fallback"
556
  onnx_codec_status = "active" if app.state.tts_wrapper.onnx_codec is not None else "inactive"
557
- onnx_backbone_status = "active" if app.state.tts_wrapper.onnx_wrapper is not None else "inactive"
558
 
559
  return {
560
  "status": "healthy",
@@ -563,7 +445,6 @@ async def health_check():
563
  "concurrency_limit": MAX_WORKERS,
564
  "onnx_optimization": onnx_status,
565
  "onnx_codec": onnx_codec_status,
566
- "onnx_backbone": onnx_backbone_status,
567
  "memory_usage": {
568
  "total_gb": round(mem.total / (1024**3), 2),
569
  "used_percent": mem.percent
@@ -613,9 +494,8 @@ async def text_to_speech(
613
  audio_duration = len(audio_data) / SAMPLE_RATE
614
 
615
  onnx_codec_active = hasattr(app.state.tts_wrapper, 'onnx_codec') and app.state.tts_wrapper.onnx_codec is not None
616
- onnx_backbone_active = hasattr(app.state.tts_wrapper, 'onnx_wrapper') and app.state.tts_wrapper.onnx_wrapper is not None
617
 
618
- logger.info(f"✅ Synthesis completed in {processing_time:.2f}s (ONNX Codec: {onnx_codec_active}, ONNX Backbone: {onnx_backbone_active})")
619
 
620
  return Response(
621
  content=audio_bytes,
@@ -624,8 +504,7 @@ async def text_to_speech(
624
  "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
625
  "X-Processing-Time": f"{processing_time:.2f}s",
626
  "X-Audio-Duration": f"{audio_duration:.2f}s",
627
- "X-ONNX-Codec-Active": str(onnx_codec_active),
628
- "X-ONNX-Backbone-Active": str(onnx_backbone_active)
629
  }
630
  )
631
  except Exception as e:
@@ -667,9 +546,7 @@ async def stream_text_to_speech_cloning(
667
  sentences = app.state.tts_wrapper._split_text_into_chunks(text)
668
 
669
  onnx_codec_active = hasattr(app.state.tts_wrapper, 'onnx_codec') and app.state.tts_wrapper.onnx_codec is not None
670
- onnx_backbone_active = hasattr(app.state.tts_wrapper, 'onnx_wrapper') and app.state.tts_wrapper.onnx_wrapper is not None
671
-
672
- logger.info(f"Streaming {len(sentences)} chunks (ONNX Codec: {onnx_codec_active}, ONNX Backbone: {onnx_backbone_active})")
673
 
674
  def process_chunk(sentence_text):
675
  with torch.no_grad():
@@ -706,13 +583,11 @@ async def stream_text_to_speech_cloning(
706
  await producer_task
707
 
708
  onnx_codec_active = hasattr(app.state.tts_wrapper, 'onnx_codec') and app.state.tts_wrapper.onnx_codec is not None
709
- onnx_backbone_active = hasattr(app.state.tts_wrapper, 'onnx_wrapper') and app.state.tts_wrapper.onnx_wrapper is not None
710
 
711
  return StreamingResponse(
712
  stream_generator(),
713
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
714
  headers={
715
- "X-ONNX-Codec-Active": str(onnx_codec_active),
716
- "X-ONNX-Backbone-Active": str(onnx_backbone_active)
717
  }
718
  )
 
131
  outputs = self.session.run(self.output_names, inputs)
132
  return outputs[0] # Assuming first output is logits
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  class NeuTTSWrapper:
135
  def __init__(self, device: str = "cpu", use_onnx: bool = USE_ONNX):
136
  self.tts_model = None
 
347
  raise ValueError("No valid speech tokens found.")
348
 
349
  def generate_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str) -> np.ndarray:
350
+ """Blocking synthesis using cached reference encoding."""
351
+ # 1. Hash the audio bytes to get a cache key
352
  audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
353
+
354
+ # 2. Get the encoding from the cache (or create it if new)
355
  ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
356
 
357
+ # 3. Infer full text (ONNX optimized if available)
358
+ with torch.no_grad():
359
+ audio = self.tts_model.infer(text, ref_s, reference_text)
360
+
361
+ return audio
 
 
362
 
363
+ # --- ONNX Conversion Function ---
364
+
365
+ def convert_model_to_onnx():
366
+ """Skip ONNX backbone conversion - use ONNX codec only for optimal performance"""
367
+ logger.info("Using ONNX codec decoder for 40% speed boost (no backbone conversion needed)")
368
+ logger.info("✅ This provides optimal performance without conversion complexity")
369
+ return False # Skip conversion attempts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  # --- Asynchronous Offloading ---
372
 
 
386
  try:
387
  # Convert to ONNX on first run if enabled but model doesn't exist
388
  if USE_ONNX and not os.path.exists(os.path.join(ONNX_MODEL_DIR, "neutts_backbone.onnx")):
389
+ logger.info("First run: Using optimized ONNX codec approach...")
390
  success = await run_blocking_task_async(convert_model_to_onnx)
391
+ if not success:
392
+ logger.info("Using PyTorch backbone + ONNX codec (optimal performance)")
 
 
393
 
394
  app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE, use_onnx=USE_ONNX)
395
 
 
433
 
434
  onnx_status = "enabled" if USE_ONNX else "disabled"
435
  onnx_codec_status = "active"
 
436
 
437
  if hasattr(app.state, 'tts_wrapper'):
438
  onnx_status = "active" if app.state.tts_wrapper.use_onnx else "fallback"
439
  onnx_codec_status = "active" if app.state.tts_wrapper.onnx_codec is not None else "inactive"
 
440
 
441
  return {
442
  "status": "healthy",
 
445
  "concurrency_limit": MAX_WORKERS,
446
  "onnx_optimization": onnx_status,
447
  "onnx_codec": onnx_codec_status,
 
448
  "memory_usage": {
449
  "total_gb": round(mem.total / (1024**3), 2),
450
  "used_percent": mem.percent
 
494
  audio_duration = len(audio_data) / SAMPLE_RATE
495
 
496
  onnx_codec_active = hasattr(app.state.tts_wrapper, 'onnx_codec') and app.state.tts_wrapper.onnx_codec is not None
 
497
 
498
+ logger.info(f"✅ Synthesis completed in {processing_time:.2f}s (ONNX Codec: {onnx_codec_active})")
499
 
500
  return Response(
501
  content=audio_bytes,
 
504
  "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
505
  "X-Processing-Time": f"{processing_time:.2f}s",
506
  "X-Audio-Duration": f"{audio_duration:.2f}s",
507
+ "X-ONNX-Codec-Active": str(onnx_codec_active)
 
508
  }
509
  )
510
  except Exception as e:
 
546
  sentences = app.state.tts_wrapper._split_text_into_chunks(text)
547
 
548
  onnx_codec_active = hasattr(app.state.tts_wrapper, 'onnx_codec') and app.state.tts_wrapper.onnx_codec is not None
549
+ logger.info(f"Streaming {len(sentences)} chunks (ONNX Codec: {onnx_codec_active})")
 
 
550
 
551
  def process_chunk(sentence_text):
552
  with torch.no_grad():
 
583
  await producer_task
584
 
585
  onnx_codec_active = hasattr(app.state.tts_wrapper, 'onnx_codec') and app.state.tts_wrapper.onnx_codec is not None
 
586
 
587
  return StreamingResponse(
588
  stream_generator(),
589
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
590
  headers={
591
+ "X-ONNX-Codec-Active": str(onnx_codec_active)
 
592
  }
593
  )