import logging

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import whisper

from model import Mamba

logging.basicConfig(level=logging.INFO)
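# Gradio demo for bilingual emotion recognition: typed text is classified directly,
# while recorded or uploaded audio is first transcribed with OpenAI Whisper and the
# transcript is then classified. The classifier is the Mamba model from model.py,
# which is assumed to expose predict_proba(list_of_texts) returning per-class
# probabilities for the seven emotions used below.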
def plotly_plot_text(text):
    """Classify the input text and return (probability bar chart, dominant-emotion markdown)."""
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😊 joy/happiness', '😐 neutral',
                       '😢 sadness', '😲 surprise/enthusiasm']
    data['Probability'] = model.predict_proba([text])[0].tolist()
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    return (
        p,
        f"## Dominant Emotion: {data['Emotion'].values[np.argmax(data['Probability'].values)]}"
    )
whisper_model = None

def transcribe_audio(audio_path):
    """Transcribe an audio file with Whisper; return an empty string on failure."""
    global whisper_model
    if whisper_model is None:
        # Load the Whisper "base" model once and reuse it across requests.
        whisper_model = whisper.load_model("base")
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""
def plotly_plot_audio(audio_path):
    """Transcribe the audio, classify the transcript, and return
    (probability bar chart, transcription text, dominant-emotion markdown)."""
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😊 joy/happiness', '😐 neutral',
                       '😢 sadness', '😲 surprise/enthusiasm']
    try:
        text = transcribe_audio(audio_path)
        if text.strip():
            data['Probability'] = model.predict_proba([text])[0].tolist()
            top_emotion = f"## Dominant Emotion: {data['Emotion'].values[np.argmax(data['Probability'].values)]}"
        else:
            # No speech recognized: show an all-zero chart instead of a spurious prediction.
            data['Probability'] = [0.0] * data.shape[0]
            top_emotion = "## Dominant Emotion: no speech detected"
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            f"🎤 Transcription:\n{text}",
            top_emotion
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "## ⚠️ Processing Error"
        )
def create_demo_text():
    with gr.Blocks(theme='Nymbo/rounded-gradient',
                   css=".gradio-container {background-color: #F0F8FF}",
                   title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition")
        with gr.Row():
            text_input = gr.Textbox(label="Write Text")
        with gr.Row():
            top_emotion = gr.Markdown("## Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
        text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, top_emotion])
    return demo
def create_demo_audio():
    with gr.Blocks(theme='Nymbo/rounded-gradient',
                   css=".gradio-container {background-color: #F0F8FF}",
                   title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition with audio transcription")
        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )
        with gr.Row():
            top_emotion = gr.Markdown("## Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
            transcription = gr.Textbox(
                label="📝 Transcription Results",
                placeholder="Transcribed text will appear here...",
                lines=3,
                max_lines=6
            )
        audio_input.change(fn=plotly_plot_audio, inputs=audio_input,
                           outputs=[text_plot, transcription, top_emotion])
    return demo
def create_demo():
    text = create_demo_text()
    audio = create_demo_audio()
    demo = gr.TabbedInterface(
        [text, audio],
        ["Text Prediction", "Transcribed Audio Prediction"],
    )
    return demo
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mamba classifier over 1024-dim Jina embeddings, predicting 7 emotion classes.
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7,
                  model_name='jina', pooling=None).to(device)
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only
    demo = create_demo()
    demo.launch()
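# Running this script directly launches the Gradio app; it expects model.py (the Mamba
# classifier) and the Mamba_jina_checkpoint.pth weights to be available in the working directory.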