Mohamedenzeyad committed (verified)
Commit 02f871d · 1 Parent(s): a3783dd

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +311 -40
src/streamlit_app.py CHANGED
@@ -1,40 +1,311 @@
- import altair as alt
- import numpy as np
- import pandas as pd
- import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ import os
+ import io
+ import csv
+ import subprocess
+ import streamlit as st
+ import numpy as np
+ import pandas as pd
+ import tensorflow as tf
+ import tensorflow_hub as hub
+ import matplotlib.pyplot as plt
+ from tensorflow import keras
+ from huggingface_hub import from_pretrained_keras
+ from audio_recorder_streamlit import audio_recorder
+ import yt_dlp
+ import torch
+ import torchaudio
+ torchaudio.set_audio_backend("soundfile")
+
+ # Check whether SpeechBrain is installed; if not, record that so the app can show a message later
+ try:
+     import speechbrain
+     from speechbrain.pretrained import EncoderClassifier
+     from speechbrain.pretrained.interfaces import foreign_class
+     speechbrain_available = True
+ except ImportError:
+     speechbrain_available = False
+
+ st.set_page_config(
+     page_title="English Accent Classification",
+     page_icon="🎙️",
+     layout="wide"
+ )
+
+ # Configuration: accent labels (assumed to match the order of the model's output probabilities)
+ xlsr_accent_classes = [
+     "US",
+     "England",
+     "Australia",
+     "Indian",
+     "Canada",
+     "Bermuda",
+     "Scotland",
+     "African",
+     "Ireland",
+     "NewZealand",
+     "Wales",
+     "Malaysia",
+     "Philippines",
+     "Singapore",
+     "HongKong",
+     "SouthAtlantic"
+ ]
+
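+ # Load the accent classifier once and cache it with st.cache_resource so Streamlit reruns reuse it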
+ @st.cache_resource
+ def load_models():
+     xlsr_model = None
+
+     if not speechbrain_available:
+         st.warning("SpeechBrain is not installed, so the XLSR accent classifier cannot be loaded.")
+         return None
+
+     try:
+         # Show a loading message while the XLSR classifier is fetched
+         with st.spinner("Loading XLSR-based accent classifier..."):
+             xlsr_model = foreign_class(
+                 source="Jzuluaga/accent-id-commonaccent_xlsr-en-english",
+                 pymodule_file="custom_interface.py",
+                 classname="CustomEncoderWav2vec2Classifier",
+                 savedir="pretrained_models/accent-id-commonaccent_xlsr-en-english"
+             )
+     except Exception as e:
+         st.warning(f"Could not load XLSR model: {e}")
+         xlsr_model = None
+
+     return xlsr_model
+
+ # Function to check if ffmpeg is installed
+ def is_ffmpeg_installed():
+     """Checks if ffmpeg is installed and in the PATH."""
+     try:
+         subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
+         return True
+     except (subprocess.CalledProcessError, FileNotFoundError) as e:
+         st.error(f"FFmpeg check failed: {e}")
+         return False
+
+ # Function to extract audio from a YouTube URL
+ def extract_audio(video_url, output_audio_path="audio.wav"):
+     """
+     Downloads the video from the URL, extracts its audio using ffmpeg, and saves it as a WAV file.
+     """
+     if not is_ffmpeg_installed():
+         st.error("FFmpeg is not installed or not in your system's PATH.")
+         st.info("Please install FFmpeg. You can download it from [FFmpeg](https://ffmpeg.org/download.html)")
+         return False
+
+     ydl_opts = {
+         'format': 'bestaudio/best',
+         'postprocessors': [{
+             'key': 'FFmpegExtractAudio',
+             'preferredcodec': 'wav',
+         }],
+         'outtmpl': 'temp_video.%(ext)s',
+     }
+     try:
+         with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+             info_dict = ydl.extract_info(video_url, download=True)
+             video_filepath = ydl.prepare_filename(info_dict)
+             # yt-dlp with FFmpegExtractAudio outputs the audio file directly;
+             # it has the same name as the downloaded video but with a .wav extension
+             base, _ = os.path.splitext(video_filepath)
+             audio_filepath = base + '.wav'
+
+             # Move the output file to the desired output_audio_path
+             if os.path.exists(audio_filepath):
+                 # Use copy instead of rename to avoid issues if the files are on different file systems
+                 import shutil
+                 shutil.copy2(audio_filepath, output_audio_path)
+                 os.remove(audio_filepath)  # Remove the original after copying
+                 st.success(f"Audio extracted successfully to {output_audio_path}")
+             else:
+                 st.error("Error: Audio file not found after extraction.")
+                 return False
+
+             # Clean up the temporary video file if it still exists (sometimes it doesn't)
+             if os.path.exists(video_filepath):
+                 os.remove(video_filepath)
+                 print(f"Cleaned up temporary video file {video_filepath}")
+
+         return True
+
+     except Exception as e:
+         st.error(f"An error occurred during audio extraction: {e}")
+         return False
+
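+ # Note: wav2vec 2.0 / XLSR models are typically trained on 16 kHz mono audio, hence the 16 kHz helpers below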
+ # Function that reads a wav audio file - without tensorflow-io
+ def load_16k_audio_wav(filename):
+     """Read an audio file and resample it to 16 kHz mono without using tensorflow-io."""
+     # Use ffmpeg to resample the audio file to 16 kHz
+     output_filename = "resampled_16k.wav"
+
+     try:
+         subprocess.run([
+             'ffmpeg', '-y', '-i', filename, '-ar', '16000', '-ac', '1', output_filename
+         ], check=True, capture_output=True)
+
+         # Read the resampled file
+         audio, sample_rate = tf.audio.decode_wav(tf.io.read_file(output_filename))
+         audio = tf.squeeze(audio, axis=-1)
+
+         # Clean up
+         if os.path.exists(output_filename):
+             os.remove(output_filename)
+
+         return audio
+     except Exception as e:
+         st.error(f"Error resampling audio: {e}")
+         # Fall back to decoding without resampling
+         audio, _ = tf.audio.decode_wav(tf.io.read_file(filename))
+         audio = tf.squeeze(audio, axis=-1)
+         return audio
+
+ # Function that takes recorded audio bytes and returns a tensor
+ def recorded_audio_to_tensor(audio_bytes):
+     # Save the audio bytes to a temporary file
+     temp_path = "temp_recorded_audio.wav"
+     with open(temp_path, "wb") as f:
+         f.write(audio_bytes)
+
+     # Load the audio file as a tensor
+     audio_tensor = load_16k_audio_wav(temp_path)
+
+     # Clean up
+     if os.path.exists(temp_path):
+         os.remove(temp_path)
+
+     return audio_tensor
+
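+ # classify_file is assumed to follow SpeechBrain's classifier interface and return (out_prob, score, index, text_lab)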
+ # Function to use the XLSR model for accent classification
+ def predict_accent_with_xlsr(audio_file_path, xlsr_model):
+     try:
+         # Classify the audio file
+         out_prob, score, index, text_lab = xlsr_model.classify_file(audio_file_path)
+
+         # Convert the prediction tensor to numpy for easier handling
+         probs = out_prob.squeeze().numpy()
+
+         # Create a dictionary of accent probabilities
+         accent_probs = {xlsr_accent_classes[i]: float(probs[i]) for i in range(len(xlsr_accent_classes))}
+
+         # Get the predicted accent (classify_file usually returns the label as a one-element list)
+         predicted_accent = text_lab[0] if isinstance(text_lab, list) else text_lab
+         confidence = float(score)
+
+         return predicted_accent, confidence, accent_probs
+     except Exception as e:
+         st.error(f"Error with XLSR prediction: {e}")
+         return None, None, None
+
+ def main():
+     st.title("English Speaker Accent Recognition")
+     st.subheader("Classify English accents using XLSR Wav2Vec 2.0")
+
+     st.write("""
+     This application detects and classifies English accents using the XLSR Wav2Vec 2.0 model.
+     """)
+
+     # Load the model
+     xlsr_model = load_models()
+
+     # Check if ffmpeg is installed
+     if not is_ffmpeg_installed():
+         st.warning("FFmpeg is not installed. You won't be able to use YouTube URLs or process some audio files correctly.")
+         st.info("Please install FFmpeg. You can download it from [FFmpeg](https://ffmpeg.org/download.html)")
+
+     # Input methods (currently only a YouTube URL tab is exposed)
+     tab3 = st.tabs(["YouTube URL"])[0]
+
+     with tab3:
+         youtube_url = st.text_input("Enter YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
+
+         if youtube_url:
+             if st.button("Extract Audio from YouTube", key="extract_btn"):
+                 with st.spinner("Extracting audio from YouTube..."):
+                     output_path = "youtube_audio.wav"
+                     if extract_audio(youtube_url, output_path):
+                         st.success("Audio extracted successfully!")
+                         st.audio(output_path, format="audio/wav")
+                         st.session_state.youtube_audio_path = output_path
+                     else:
+                         st.error("Failed to extract audio from YouTube URL.")
+
+     # Process and analyze the audio when the button is clicked
+     if st.button("Predict Accent", type="primary"):
+         audio_file_path = None
+
+         # Check which audio source we have
+         if 'youtube_audio_path' in st.session_state and os.path.exists(st.session_state.youtube_audio_path):
+             audio_file_path = st.session_state.youtube_audio_path
+         else:
+             st.warning("Please provide a YouTube URL.")
+             st.stop()
+
+         # Run prediction with the XLSR model
+         if xlsr_model is not None:
+             with st.spinner("Analyzing audio with XLSR Wav2Vec 2.0..."):
+                 xlsr_predicted_accent, xlsr_confidence, xlsr_accent_probs = predict_accent_with_xlsr(
+                     audio_file_path, xlsr_model
+                 )
+
+             if xlsr_predicted_accent:
+                 st.success(f"🎯 **Predicted Accent: {xlsr_predicted_accent}** (Confidence: {xlsr_confidence:.2f})")
+
+                 # Sort the accent probabilities for visualization
+                 sorted_probs = {k: v for k, v in sorted(xlsr_accent_probs.items(), key=lambda item: item[1], reverse=True)}
+
+                 # Create a bar chart
+                 fig, ax = plt.subplots(figsize=(10, 6))
+                 accents = list(sorted_probs.keys())
+                 probabilities = list(sorted_probs.values())
+
+                 ax.bar(accents, probabilities, color='lightcoral')
+                 ax.set_ylabel('Probability')
+                 ax.set_title('XLSR Wav2Vec 2.0 Accent Probabilities (16 English Accents)')
+                 plt.xticks(rotation=45)
+                 plt.tight_layout()
+
+                 st.pyplot(fig)
+
+                 # Also display the probabilities as a table
+                 df = pd.DataFrame({
+                     'Accent': accents,
+                     'Probability': [f"{p:.2%}" for p in probabilities]
+                 })
+                 st.dataframe(df, hide_index=True)
+
+                 # Add information about the XLSR model
+                 st.info("""
+                 🚀 **XLSR Wav2Vec 2.0 Model**: This state-of-the-art model achieves up to 95% accuracy
+                 and can distinguish between 16 different English accent regions, including less common
+                 varieties such as Bermuda, Hong Kong, and South Atlantic.
+                 """)
+             else:
+                 st.error("The XLSR model failed to classify the accent.")
+         else:
+             st.error("The XLSR model is not loaded, so the accent cannot be predicted.")
+
+         # Clean up temporary files
+         if audio_file_path and audio_file_path.startswith("temp_") and os.path.exists(audio_file_path):
+             os.remove(audio_file_path)
+
+     # Add information about the model
+     st.markdown("---")
+     st.subheader("About the Model")
+
+     st.markdown("### XLSR Wav2Vec 2.0 ⭐")
+     st.write("""
+     **State-of-the-art** model with 95% accuracy for English accent classification.
+
+     **Supported accents:**
+     - US, England, Australia, India
+     - Canada, Bermuda, Scotland, Africa
+     - Ireland, New Zealand, Wales
+     - Malaysia, Philippines, Singapore
+     - Hong Kong, South Atlantic
+
+     Based on self-supervised Wav2Vec 2.0 with cross-lingual representations.
+     """)
+
+     # Credits
+     st.markdown("---")
+     st.markdown("""
+     **Credits:**
+     - **XLSR Model**: [Jzuluaga/accent-id-commonaccent_xlsr-en-english](https://huggingface.co/Jzuluaga/accent-id-commonaccent_xlsr-en-english) by Juan Zuluaga-Gomez et al.
+     - All SpeechBrain models by [SpeechBrain](https://speechbrain.github.io/)
+     """)
+
+
+ if __name__ == "__main__":
+     main()