import gradio as gr
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio
import torchaudio.transforms as T
import logging
import json
import importlib

modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")
from Prediction_Head.MTGGenre_head import MLPProberBase
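
# This Space loads a local MERT-v0-public checkpoint as a feature backbone,
# mean-pools the hidden states of one transformer layer, and feeds them to a
# small MLP prediction head that scores music genre/tag classes. The file
# upload, microphone recording, and live-streaming tabs below all share the
# same pipeline.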
# input handling adapted from: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py

logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)
inputs = [
    gr.Audio(type="filepath", label="Add music audio file"),
    gr.Audio(source="microphone", type="filepath"),
]
live_inputs = [
    gr.Audio(source="microphone", streaming=True, type="filepath"),
]
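
# NOTE: the `source=`/`streaming=` arguments on gr.Audio and the
# `queue(concurrency_count=...)` call at the bottom follow the Gradio 3.x API
# this Space appears to target; Gradio 4+ renamed or removed these arguments,
# so pin an older gradio version when reproducing this setup.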
# outputs = [gr.components.Textbox()]
# outputs = [gr.components.Textbox(), transcription_df]

title = "Predict the top 5 possible genres and tags of Music"
description = "An example of using the m-a-p/MERT-v0-public (95M) model as a backbone for music genre/tagging prediction."
article = ""
audio_examples = [
    # ["input/example-1.wav"],
    # ["input/example-2.wav"],
]
# Load the model and the corresponding preprocessor config from the local
# checkpoint. (The commented lines would pull the same weights from the Hub.)
# model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")

MERT_LAYER_IDX = 7  # which of the 13 hidden-state layers to probe (chosen empirically)

MTGGenre_classifier = MLPProberBase()
MTGGenre_classifier.load_state_dict(
    torch.load('Prediction_Head/best_MTGGenre.ckpt')['state_dict'])

with open('Prediction_Head/MTGGenre_id2class.json', 'r') as f:
    id2cls = json.load(f)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
MTGGenre_classifier.to(device)
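
# The Prediction_Head checkpoint appears to be an MLP probe over frozen MERT
# features: it maps the 768-dim, mean-pooled layer representation to 87 genre
# logits, and id2cls maps each logit index back to a "genre---<name>" label.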
def convert_audio(inputs, microphone):
    # The recorded clip (if any) takes precedence over the uploaded file.
    if microphone is not None:
        inputs = microphone
    if inputs is None:
        return "No audio provided."
    waveform, sample_rate = torchaudio.load(inputs)

    resample_rate = processor.sampling_rate
    # make sure the sample rate matches what the processor expects
    if resample_rate != sample_rate:
        print(f'setting rate from {sample_rate} to {resample_rate}')
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)
    waveform = waveform.view(-1,)  # flatten to a 1-D (n_samples,) tensor

    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)

    # There are 13 layers of representation; each layer performs differently on
    # different downstream tasks, so the probed layer is chosen empirically.
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
    print(all_layer_hidden_states.shape)  # [13 layers, time steps, 768 feature_dim]

    # Mean-pool the chosen layer over time, then score every genre class.
    logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0))  # [87]
    print(logits.shape)

    sorted_idx = torch.argsort(logits, dim=-1, descending=True)
    output_texts = "\n".join(
        [id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
    # logger.warning(all_layer_hidden_states.shape)
    # return f"device {device}, sample representation: {str(all_layer_hidden_states[12, 0, :10])}"
    return f"device: {device}\n" + output_texts
def live_convert_audio(microphone):
    # The live-streaming tab runs the same pipeline as the file/recording tab,
    # just fed from the microphone chunks Gradio streams in.
    return convert_audio(microphone, None)
audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=[gr.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)

live_audio_chunked = gr.Interface(
    fn=live_convert_audio,
    inputs=live_inputs,
    outputs=[gr.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    # examples=audio_examples,
    live=True,
)
demo = gr.Blocks()
with demo:
    gr.TabbedInterface(
        [
            audio_chunked,
            live_audio_chunked,
        ],
        [
            "Audio File or Recording",
            "Live Streaming Music",
        ],
    )

demo.queue(concurrency_count=1, max_size=5)
demo.launch(show_api=False)