Thanh-Lam committed on
Commit f9308ba · 1 Parent(s): c3418e9

Add multi-model support with PhoWhisper and model selection dropdown

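For context, a minimal usage sketch of the reworked app (class and function names are taken from the new app.py in the diff below; the checkpoint folders referenced in MODELS_CONFIG and the path sample.wav are assumptions for illustration):

    from app import MultiModelProfiler, create_interface

    # Pre-loads every checkpoint found under model/ ("Wav2Vec2 Vietnamese", "PhoWhisper")
    profiler = MultiModelProfiler()
    print(profiler.get_available_models())

    # Returns formatted strings such as "Female (97.3%)" and "Northern (88.1%)" (values illustrative)
    gender, dialect = profiler.predict("sample.wav", "PhoWhisper")

    # Or launch the full Gradio UI with the model-selection dropdown
    create_interface().launch(server_name="0.0.0.0", server_port=7860)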
app.py CHANGED
@@ -1,316 +1,285 @@
  """
- Gradio Web Interface for Speaker Profiling
-
- Usage:
-     python app.py
-     python app.py --config configs/infer.yaml --share
  """

  import os
- import argparse
- import tempfile
- import time
- import numpy as np
  import torch
- import librosa
  import gradio as gr
  from pathlib import Path
-
- from src.models import MultiTaskSpeakerModel
- from src.utils import (
-     setup_logging,
-     get_logger,
-     load_config,
-     get_device,
-     load_model_checkpoint,
-     preprocess_audio
- )
-
-
- class SpeakerProfilerApp:
-     """Gradio application for speaker profiling"""
-
-     def __init__(self, config_path: str):
-         self.logger = setup_logging(name="gradio_app")
-         self.config = load_config(config_path)
-         self.device = get_device(self.config['inference']['device'])
-
-         self.sampling_rate = self.config['audio']['sampling_rate']
-         self.max_duration = self.config['audio']['max_duration']
-
-         self.gender_labels = self.config['labels']['gender']
-         self.dialect_labels = self.config['labels']['dialect']
-
-         self._load_model()
-
-     def _load_model(self):
-         """Load model and feature extractor"""
-         from transformers import Wav2Vec2FeatureExtractor, WhisperFeatureExtractor
-
-         self.logger.info("Loading model...")
-
-         model_name = self.config['model']['name']
-         is_ecapa = 'ecapa' in model_name.lower() or 'speechbrain' in model_name.lower()
-
-         # Check if this is a Whisper/PhoWhisper model
-         self.is_whisper = 'whisper' in model_name.lower() or 'phowhisper' in model_name.lower()
-
-         if is_ecapa:
-             # ECAPA-TDNN: use Wav2Vec2 feature extractor for audio normalization
-             self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-                 "facebook/wav2vec2-base"
-             )
-         elif self.is_whisper:
-             # Whisper/PhoWhisper: use WhisperFeatureExtractor
-             self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
-                 model_name
-             )
-         else:
-             self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-                 self.config['model']['checkpoint']
-             )
-
-         self.model = MultiTaskSpeakerModel(model_name)
-         self.model = load_model_checkpoint(
-             self.model,
-             self.config['model']['checkpoint'],
-             str(self.device)
-         )
-
-         self.model.to(self.device)
-         self.model.eval()
-
-         self.logger.info(f"Model loaded on {self.device}")
-
-     def predict(self, audio_input):
-         """
-         Predict gender and dialect from audio
-
-         Args:
-             audio_input: Tuple of (sample_rate, audio_array) from Gradio
-
-         Returns:
-             Tuple of (gender_result, dialect_result, details)
-         """
-         if audio_input is None:
-             return "No audio", "No audio", "Please upload or record audio"
-
          try:
-             sr, audio = audio_input
-
-             if len(audio.shape) > 1:
-                 audio = audio.mean(axis=1)
-
-             audio = audio.astype(np.float32)
-             if audio.max() > 1.0:
-                 audio = audio / 32768.0
-
-             if sr != self.sampling_rate:
-                 audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sampling_rate)
-
-             # Calculate original audio duration BEFORE preprocessing
-             audio_duration = len(audio) / self.sampling_rate
-
-             # Whisper requires 30 seconds of audio
-             if self.is_whisper:
-                 max_duration = 30
              else:
-                 max_duration = self.max_duration
-
-             audio = preprocess_audio(
-                 audio,
-                 sampling_rate=self.sampling_rate,
-                 max_duration=max_duration
-             )
-
-             # Whisper needs exactly 30 seconds - pad if necessary
-             if self.is_whisper:
-                 target_len = self.sampling_rate * 30
-                 if len(audio) < target_len:
-                     audio = np.pad(audio, (0, target_len - len(audio)))
-
-             inputs = self.feature_extractor(
-                 audio,
-                 sampling_rate=self.sampling_rate,
-                 return_tensors="pt",
-                 padding=True
              )
-
-             # Whisper uses 'input_features', WavLM/HuBERT/Wav2Vec2 use 'input_values'
-             if self.is_whisper:
-                 input_values = inputs.input_features.to(self.device)
              else:
-                 input_values = inputs.input_values.to(self.device)
-
-             # Measure inference time
-             start_time = time.perf_counter()
-
-             with torch.no_grad():
-                 outputs = self.model(input_values)
-                 gender_logits = outputs['gender_logits']
-                 dialect_logits = outputs['dialect_logits']
-
-             # Calculate inference time
-             infer_time = (time.perf_counter() - start_time) * 1000  # Convert to ms
-
-             gender_probs = torch.softmax(gender_logits, dim=-1).cpu().numpy()[0]
-             dialect_probs = torch.softmax(dialect_logits, dim=-1).cpu().numpy()[0]
-
-             gender_pred = int(np.argmax(gender_probs))
-             dialect_pred = int(np.argmax(dialect_probs))
-
-             gender_name = self.gender_labels[gender_pred]
-             dialect_name = self.dialect_labels[dialect_pred]
-
-             gender_conf = gender_probs[gender_pred] * 100
-             dialect_conf = dialect_probs[dialect_pred] * 100
-
-             gender_result = f"{gender_name} ({gender_conf:.1f}%)"
-             dialect_result = f"{dialect_name} ({dialect_conf:.1f}%)"
-
-             details = self._format_details(gender_probs, dialect_probs, infer_time, audio_duration)
-
-             self.logger.info(f"Prediction: Gender={gender_name}, Dialect={dialect_name} | Inference time: {infer_time:.2f}ms | Audio: {audio_duration:.2f}s")
-
-             return gender_result, dialect_result, details
-
          except Exception as e:
-             self.logger.error(f"Prediction error: {e}")
-             return "Error", "Error", f"Error: {str(e)}"
-
-     def _format_details(self, gender_probs: np.ndarray, dialect_probs: np.ndarray, infer_time: float = None, audio_duration: float = None) -> str:
-         """Format detailed prediction results"""
-         # Gender label names
-         gender_names = ['Female', 'Male']
-         # Dialect label names
-         dialect_names = ['North', 'Central', 'South']
-
-         lines = []
-         lines.append("Gender Probabilities:")
-         for i, name in enumerate(gender_names):
-             lines.append(f" {name}: {gender_probs[i]*100:.2f}%")
-
-         lines.append("")
-         lines.append("Dialect Probabilities:")
-         for i, name in enumerate(dialect_names):
-             lines.append(f" {name}: {dialect_probs[i]*100:.2f}%")
-
-         lines.append("")
-         lines.append("─" * 30)
-
-         if audio_duration is not None:
-             lines.append(f"Audio Duration: {audio_duration:.2f} s")
-
-         if infer_time is not None:
-             lines.append(f"Inference Time: {infer_time:.2f} ms")
-
-         return "\n".join(lines)
-
-     def create_interface(self) -> gr.Blocks:
-         """Create Gradio interface"""
-
-         # Gradio < 4.0 doesn't support theme in Blocks
-         with gr.Blocks(title="Vietnamese Speaker Profiling") as demo:
-
-             gr.Markdown(
-                 """
-                 # Vietnamese Speaker Profiling
-
-                 Identify gender and dialect from Vietnamese speech audio.
-
-                 **Model:** Encoder + Attentive Pooling + LayerNorm + MultiHead Classifier
-
-                 **Supported dialects:** North, Central, South
-                 """
-             )
-
-             with gr.Row():
-                 with gr.Column(scale=1):
-                     audio_input = gr.Audio(
-                         label="Input Audio",
-                         type="numpy",
-                         sources=["upload", "microphone"]
-                     )
-
-                     submit_btn = gr.Button("Analyze", variant="primary")
-                     clear_btn = gr.Button("Clear")
-
-                 with gr.Column(scale=1):
-                     gender_output = gr.Textbox(
-                         label="Gender",
-                         interactive=False
-                     )
-                     dialect_output = gr.Textbox(
-                         label="Dialect",
-                         interactive=False
-                     )
-                     details_output = gr.Textbox(
-                         label="Details",
-                         lines=8,
-                         interactive=False
-                     )
-
-             gr.Markdown(
-                 """
-                 ---
-                 **Notes:**
-                 - Supported formats: WAV, MP3
-                 - Recommended duration: 3-10 seconds
-                 """
-             )
-
-             submit_btn.click(
-                 fn=self.predict,
-                 inputs=[audio_input],
-                 outputs=[gender_output, dialect_output, details_output]
-             )
-
-             clear_btn.click(
-                 fn=lambda: (None, "", "", ""),
-                 inputs=[],
-                 outputs=[audio_input, gender_output, dialect_output, details_output]
-             )
-
-         return demo
-
-
- def main():
-     """Main function"""
-     parser = argparse.ArgumentParser(description="Speaker Profiling Web Interface")
-     parser.add_argument(
-         "--config",
-         type=str,
-         default="configs/infer.yaml",
-         help="Path to config file"
-     )
-     parser.add_argument(
-         "--share",
-         action="store_true",
-         help="Create public link"
-     )
-     parser.add_argument(
-         "--port",
-         type=int,
-         default=7860,
-         help="Port number (default: 7860)"
-     )
-     parser.add_argument(
-         "--server_name",
-         type=str,
-         default="0.0.0.0",
-         help="Server name (default: 0.0.0.0)"
-     )
-     args = parser.parse_args()
-
-     app = SpeakerProfilerApp(args.config)
-     demo = app.create_interface()
-
-     demo.launch(
-         server_name=args.server_name,
-         server_port=args.port,
-         share=args.share
-     )
-
-
  if __name__ == "__main__":
-     main()

  """
+ Vietnamese Speaker Profiling - Multi-Model Gradio Interface
+ Supports: Vietnamese Wav2Vec2 and PhoWhisper encoders
  """

  import os
  import torch
+ import torchaudio
  import gradio as gr
  from pathlib import Path
+ from safetensors.torch import load_file as load_safetensors
+
+ # Model configurations
+ MODELS_CONFIG = {
+     "Wav2Vec2 Vietnamese": {
+         "path": "model/vulehuubinh",
+         "encoder_name": "nguyenvulebinh/wav2vec2-base-vi-vlsp2020",
+         "is_whisper": False,
+         "description": "Vietnamese Wav2Vec2 pretrained model - Fast inference"
+     },
+     "PhoWhisper": {
+         "path": "model/pho",
+         "encoder_name": "vinai/PhoWhisper-base",
+         "is_whisper": True,
+         "description": "Vietnamese Whisper model - Higher accuracy"
+     }
+ }
+
+ # Labels
+ GENDER_LABELS = ["Male", "Female"]
+ DIALECT_LABELS = ["Northern", "Central", "Southern"]
+
+
+ class MultiModelProfiler:
+     """Speaker Profiler supporting multiple encoder models."""
+
+     def __init__(self):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.sampling_rate = 16000
+         self.models = {}
+         self.processors = {}
+         self.current_model = None
+
+         print(f"Using device: {self.device}")
+
+         # Pre-load all models
+         self._load_all_models()
+
+     def _load_all_models(self):
+         """Load all available models."""
+         for model_name, config in MODELS_CONFIG.items():
+             model_path = Path(config["path"])
+             if model_path.exists():
+                 print(f"Loading {model_name}...")
+                 self._load_single_model(model_name, config)
+             else:
+                 print(f"Model not found: {model_path}")
+
+     def _load_single_model(self, model_name: str, config: dict):
+         """Load a specific model."""
          try:
+             model_path = Path(config["path"])
+             is_whisper = config["is_whisper"]
+             encoder_name = config["encoder_name"]
+
+             # Load processor
+             if is_whisper:
+                 from transformers import WhisperFeatureExtractor
+                 processor = WhisperFeatureExtractor.from_pretrained(encoder_name)
              else:
+                 from transformers import Wav2Vec2FeatureExtractor
+                 processor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_name)
+
+             # Load model
+             from src.models import SpeakerProfileModel
+
+             model = SpeakerProfileModel(
+                 model_name=encoder_name,
+                 num_gender_classes=2,
+                 num_dialect_classes=3,
+                 dropout=0.1,
+                 freeze_encoder=True
              )
+
+             # Load checkpoint from safetensors
+             checkpoint_path = model_path / "model.safetensors"
+             if checkpoint_path.exists():
+                 state_dict = load_safetensors(str(checkpoint_path))
+                 model.load_state_dict(state_dict)
+                 print(f"Loaded checkpoint: {checkpoint_path}")
              else:
+                 # Try loading from .pt file
+                 pt_path = model_path / "best_model.pt"
+                 if pt_path.exists():
+                     checkpoint = torch.load(pt_path, map_location=self.device, weights_only=False)
+                     if "model_state_dict" in checkpoint:
+                         model.load_state_dict(checkpoint["model_state_dict"])
+                     else:
+                         model.load_state_dict(checkpoint)
+                     print(f"Loaded checkpoint: {pt_path}")
+
+             model.to(self.device)
+             model.eval()
+
+             self.models[model_name] = model
+             self.processors[model_name] = processor
+
+             if self.current_model is None:
+                 self.current_model = model_name
+
+             print(f"✓ {model_name} loaded successfully")
+
+         except Exception as e:
+             print(f"✗ Error loading {model_name}: {e}")
+             import traceback
+             traceback.print_exc()
+
+     def predict(self, audio_path: str, model_name: str):
+         """Predict gender and dialect from audio."""
+         if model_name not in self.models:
+             available = list(self.models.keys())
+             if not available:
+                 return "No models available", "No models available"
+             model_name = available[0]
+
+         try:
+             model = self.models[model_name]
+             processor = self.processors[model_name]
+             is_whisper = MODELS_CONFIG[model_name]["is_whisper"]
+
+             # Load audio
+             waveform, sr = torchaudio.load(audio_path)
+
+             # Convert to mono
+             if waveform.shape[0] > 1:
+                 waveform = waveform.mean(dim=0, keepdim=True)
+
+             # Resample if needed
+             if sr != self.sampling_rate:
+                 resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
+                 waveform = resampler(waveform)
+
+             waveform = waveform.squeeze(0).numpy()
+
+             # Process based on model type
+             if is_whisper:
+                 # Whisper requires exactly 30 seconds of audio
+                 whisper_length = self.sampling_rate * 30  # 480000 samples
+                 if len(waveform) < whisper_length:
+                     waveform_padded = torch.nn.functional.pad(
+                         torch.tensor(waveform),
+                         (0, whisper_length - len(waveform))
+                     ).numpy()
+                 else:
+                     waveform_padded = waveform[:whisper_length]
+
+                 inputs = processor(
+                     waveform_padded,
+                     sampling_rate=self.sampling_rate,
+                     return_tensors="pt"
+                 )
+                 input_tensor = inputs.input_features.to(self.device)
+             else:
+                 # Wav2Vec2 uses raw waveform
+                 inputs = processor(
+                     waveform,
+                     sampling_rate=self.sampling_rate,
+                     return_tensors="pt",
+                     padding=True
+                 )
+                 input_tensor = inputs.input_values.to(self.device)
+
+             # Inference
+             with torch.no_grad():
+                 gender_logits, dialect_logits = model(input_tensor)
+
+             gender_probs = torch.softmax(gender_logits, dim=-1)
+             dialect_probs = torch.softmax(dialect_logits, dim=-1)
+
+             gender_idx = gender_probs.argmax(dim=-1).item()
+             dialect_idx = dialect_probs.argmax(dim=-1).item()
+
+             gender_conf = gender_probs[0, gender_idx].item() * 100
+             dialect_conf = dialect_probs[0, dialect_idx].item() * 100
+
+             gender_result = f"{GENDER_LABELS[gender_idx]} ({gender_conf:.1f}%)"
+             dialect_result = f"{DIALECT_LABELS[dialect_idx]} ({dialect_conf:.1f}%)"
+
+             return gender_result, dialect_result
+
          except Exception as e:
+             import traceback
+             traceback.print_exc()
+             return f"Error: {str(e)}", f"Error: {str(e)}"
+
+     def get_available_models(self):
+         """Get list of available models."""
+         return list(self.models.keys())
+
+
+ def create_interface():
+     """Create Gradio interface with model selection."""
+
+     profiler = MultiModelProfiler()
+     available_models = profiler.get_available_models()
+
+     if not available_models:
+         available_models = ["No models available"]
+
+     def predict_wrapper(audio, model_name):
+         if audio is None:
+             return "Please upload audio", "Please upload audio"
+         return profiler.predict(audio, model_name)
+
+     # Create model info text
+     model_info = ""
+     for name, config in MODELS_CONFIG.items():
+         status = "✓" if name in profiler.models else "✗"
+         model_info += f"{status} **{name}**: {config['description']}\n"
+
+     with gr.Blocks(title="Vietnamese Speaker Profiling", theme=gr.themes.Soft()) as demo:
+         gr.Markdown(
+             """
+             # 🎙️ Vietnamese Speaker Profiling
+
+             Analyze Vietnamese speech to predict **Gender** and **Dialect Region**.
+
+             Supports multiple AI models - choose the one that works best for you!
+             """
+         )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📤 Input")
+                 audio_input = gr.Audio(
+                     label="Upload or Record Audio",
+                     type="filepath",
+                     sources=["upload", "microphone"]
+                 )
+
+                 model_dropdown = gr.Dropdown(
+                     choices=available_models,
+                     value=available_models[0] if available_models else None,
+                     label="🤖 Select Model",
+                     info="Choose the AI model for analysis"
+                 )
+
+                 submit_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")
+
+                 gr.Markdown("### ℹ️ Available Models")
+                 gr.Markdown(model_info)
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📊 Results")
+                 gender_output = gr.Textbox(label="👤 Gender", interactive=False)
+                 dialect_output = gr.Textbox(label="🗣️ Dialect Region", interactive=False)
+
+                 gr.Markdown(
+                     """
+                     ### 📖 Dialect Regions
+                     - **Northern**: Hanoi and surrounding areas
+                     - **Central**: Huế, Đà Nẵng, and Central Vietnam
+                     - **Southern**: Ho Chi Minh City and Southern Vietnam
+                     """
+                 )
+
+         submit_btn.click(
+             fn=predict_wrapper,
+             inputs=[audio_input, model_dropdown],
+             outputs=[gender_output, dialect_output]
+         )
+
+         gr.Markdown(
+             """
+             ---
+             *Made with ❤️ for Vietnamese Speech Processing Research*
+             """
+         )
+
+     return demo
+
+
  if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
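The new code imports SpeakerProfileModel from src.models, which is not part of this commit; app.py only assumes its constructor arguments and that the forward pass returns a (gender_logits, dialect_logits) pair. Below is a minimal illustrative stand-in for the Wav2Vec2 case (the PhoWhisper case would wrap the Whisper encoder and consume input_features instead), not the repository's actual implementation:

    import torch
    import torch.nn as nn
    from transformers import Wav2Vec2Model

    class SpeakerProfileModelSketch(nn.Module):
        # Hypothetical stand-in: mirrors only the interface app.py relies on.
        def __init__(self, model_name, num_gender_classes=2, num_dialect_classes=3,
                     dropout=0.1, freeze_encoder=True):
            super().__init__()
            self.encoder = Wav2Vec2Model.from_pretrained(model_name)
            if freeze_encoder:
                for p in self.encoder.parameters():
                    p.requires_grad = False
            hidden = self.encoder.config.hidden_size
            self.dropout = nn.Dropout(dropout)
            self.gender_head = nn.Linear(hidden, num_gender_classes)
            self.dialect_head = nn.Linear(hidden, num_dialect_classes)

        def forward(self, input_values):
            frames = self.encoder(input_values).last_hidden_state  # (batch, time, hidden)
            pooled = self.dropout(frames.mean(dim=1))               # mean pooling as a simple stand-in
            return self.gender_head(pooled), self.dialect_head(pooled)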
model/pho/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e14ed1dd391d230ba74d231d164e626c1a9e9b865d0c56a87af4351e92b9557
+ size 292648364
model/pho/preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "chunk_length": 30,
+   "feature_extractor_type": "WhisperFeatureExtractor",
+   "feature_size": 80,
+   "hop_length": 160,
+   "n_fft": 400,
+   "n_samples": 480000,
+   "nb_max_frames": 3000,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "processor_class": "WhisperProcessor",
+   "return_attention_mask": false,
+   "sampling_rate": 16000
+ }
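The derived fields in this feature-extractor config follow directly from the basic Whisper settings; a quick consistency check (not part of the app):

    chunk_length, sampling_rate, hop_length, feature_size = 30, 16000, 160, 80
    n_samples = chunk_length * sampling_rate   # 30 s * 16 kHz = 480000, matches "n_samples"
    nb_max_frames = n_samples // hop_length    # 480000 / 160 = 3000, matches "nb_max_frames"
    # WhisperFeatureExtractor therefore yields input_features of shape (1, 80, 3000) per clip
    print(n_samples, nb_max_frames, (1, feature_size, nb_max_frames))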
model/pho/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8581872846c31ff58536a0780aed646dd2c25671e9318390007d5784c62dc39d
+ size 5176