ACloudCenter's picture
Modify Title
b3b752d
raw
history blame
3.98 kB
import gradio as gr
import spaces
import torch
import soundfile as sf
import numpy as np
import librosa
import math
from transformers import MoonshineForConditionalGeneration, AutoProcessor
# Single source of truth for the checkpoint so the model and its processor
# can never be loaded from different repos by accident.
MODEL_ID = "UsefulSensors/moonshine-tiny"

# Use GPU if available and pick a matching dtype: fp16 on CUDA, fp32 on CPU.
_use_cuda = torch.cuda.is_available()
device = "cuda:0" if _use_cuda else "cpu"
torch_dtype = torch.float16 if _use_cuda else torch.float32

# Load model and processor - Moonshine Tiny.
# Move and cast the weights in a single Module.to call.
model = MoonshineForConditionalGeneration.from_pretrained(MODEL_ID).to(
    device=device, dtype=torch_dtype
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Generation budget for greedy decoding. max_new_tokens is proportional to the
# clip length (tokens per second of audio), with a floor for very short clips,
# so decoding cannot run far past the end of the speech.
MIN_NEW_TOKENS = 24
TOKENS_PER_SECOND = 7.0


# Define transcription function using HF Zero GPU
@spaces.GPU
def transcribe_audio(audio_file):
    """Transcribe one audio file with Moonshine Tiny.

    Args:
        audio_file: Path to an audio file, as produced by a ``gr.Audio``
            component with ``type="filepath"``. ``None`` or an empty string
            means the user submitted nothing.

    Returns:
        The decoded transcription text. If there is no usable audio, or the
        file cannot be decoded, a short human-readable message is returned
        instead, so it shows up in the output textbox.
    """
    if not audio_file:
        return "No audio provided."

    # Load audio. soundfile reports unreadable or unsupported files with a
    # RuntimeError (or a subclass of it); surface the reason to the user
    # instead of failing the whole request.
    try:
        audio_array, sr = sf.read(audio_file)
    except RuntimeError as err:
        return f"Could not read audio file: {err}"

    # A zero-length recording cannot be featurized; treat it as missing input.
    if audio_array.size == 0:
        return "No audio provided."

    # sf.read returns (frames, channels) for multi-channel audio: average down to mono.
    if audio_array.ndim > 1:
        audio_array = np.mean(audio_array, axis=1)

    # Resample if the file's rate differs from what the feature extractor expects.
    target_sr = processor.feature_extractor.sampling_rate
    if sr != target_sr:
        audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=target_sr)

    # Build model inputs and move them to the model's device and dtype.
    inputs = processor(
        audio_array,
        sampling_rate=target_sr,
        return_tensors="pt",
    ).to(device, torch_dtype)

    # Size the generation budget from the clip duration (see constants above).
    duration_sec = len(audio_array) / float(target_sr)
    max_new_tokens = max(
        MIN_NEW_TOKENS, int(math.ceil(duration_sec * TOKENS_PER_SECOND))
    )

    # Greedy (non-sampled) decoding gives deterministic transcripts.
    generated_ids = model.generate(
        **inputs, do_sample=False, max_new_tokens=max_new_tokens
    )
    # Decode the single sequence in the batch, dropping special tokens.
    return processor.decode(generated_ids[0], skip_special_tokens=True)
# Set Gradio theme: start from the built-in Ocean theme with custom hues...
theme = gr.themes.Ocean(
    neutral_hue="slate",
    primary_hue="indigo",
    secondary_hue="fuchsia",
)
# ...then set the large-button corner radius to the theme's small-radius token.
theme = theme.set(button_large_radius="*radius_sm")
# Create Gradio interface: a title, a banner, two input tabs (file upload and
# microphone), each wired to transcribe_audio, and usage instructions.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("## Moonshine Tiny STT - 27M Parameters")
    # Banner image served from this Space's repository.
    gr.HTML("""
<div style="width: 100%; margin-bottom: 20px;">
<img src="https://huggingface.co/spaces/ACloudCenter/moonshine-tiny-STT/resolve/main/public/images/banner.png"
style="width: 100%; height: auto; border-radius: 15px; box-shadow: 0 10px 40px rgba(0,0,0,0.2);"
alt="Moonshine Tiny STT Banner">
</div>
""")
    with gr.Tabs():
        # Tab 1: transcribe an uploaded file. type="filepath" passes a path
        # on disk to transcribe_audio rather than raw samples.
        with gr.TabItem("Upload Audio"):
            audio_file = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Upload Audio File",
            )
            output_text1 = gr.Textbox(
                label="Transcription",
                placeholder="Transcription will appear here...",
            )
            upload_button = gr.Button("Transcribe Uploaded Audio")
            upload_button.click(
                fn=transcribe_audio,
                inputs=audio_file,
                outputs=output_text1,
            )
        # Tab 2: same pipeline, fed from a microphone recording.
        with gr.TabItem("Record Audio"):
            audio_mic = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record Audio",
            )
            output_text2 = gr.Textbox(
                label="Transcription",
                placeholder="Transcription will appear here...",
            )
            record_button = gr.Button("Transcribe Recorded Audio")
            record_button.click(
                fn=transcribe_audio,
                inputs=audio_mic,
                outputs=output_text2,
            )
    gr.Markdown("""
### Instructions:
1. Choose either 'Upload Audio' or 'Record Audio' tab
2. Upload an audio file or record using your microphone
3. Click the respective 'Transcribe' button
4. Wait for the transcription to appear
""")

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()