# Author: ACloudCenter · commit 145285e ("Add gpu decorator") · 3.31 kB
import gradio as gr
import spaces
import torch
from datasets import load_dataset, Audio
from transformers import MoonshineForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio
# Decide the execution target once: CUDA (with half precision for faster
# inference) when a GPU is visible, otherwise CPU in full float32 precision.
_cuda_available = torch.cuda.is_available()
device = "cuda:0" if _cuda_available else "cpu"
torch_dtype = torch.float16 if _cuda_available else torch.float32

# Single checkpoint ID shared by the model weights and the matching processor.
_MODEL_ID = 'UsefulSensors/moonshine-tiny'

# Load the pretrained Moonshine model, then move it to the device and cast it.
model = MoonshineForConditionalGeneration.from_pretrained(_MODEL_ID)
model = model.to(device).to(torch_dtype)
processor = AutoProcessor.from_pretrained(_MODEL_ID)
# Function to transcribe audio
@spaces.GPU
def transcribe_audio(audio_file):
# Load audio file
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
sample = dataset[0]["audio"]
inputs = processor(
sample["array"],
return_tensors="pt",
sampling_rate=processor.feature_extractor.sampling_rate
)
inputs = inputs.to(device, torch_dtype)
# to avoid hallucination loops, we limit the maximum length of the generated text based expected number of tokens per second
token_limit_factor = 6.5 / processor.feature_extractor.sampling_rate # Maximum of 6.5 tokens per second
seq_lens = inputs.attention_mask.sum(dim=-1)
max_length = int((seq_lens * token_limit_factor).max().item())
generated_ids = model.generate(**inputs, max_length=max_length)
return processor.decode(generated_ids[0], skip_special_tokens=True)
# Create Gradio interface
def _build_transcription_tab(tab_title, audio_source, audio_label, button_label):
    """Lay out one transcription tab and wire its button to ``transcribe_audio``.

    Called from inside the ``gr.Tabs()`` context of ``demo`` below, so the
    new ``gr.TabItem`` is attached to that tab group.

    Args:
        tab_title: Title shown on the tab.
        audio_source: The single ``gr.Audio`` source for this tab
            (``"upload"`` or ``"microphone"`` in this app).
        audio_label: Label of the audio input component.
        button_label: Text shown on the transcribe button.

    Returns:
        A ``(audio_input, transcription_textbox, button)`` tuple, in the
        order the components were created.
    """
    with gr.TabItem(tab_title):
        audio_input = gr.Audio(
            sources=[audio_source],
            type="filepath",
            label=audio_label,
        )
        transcript_box = gr.Textbox(
            label="Transcription",
            placeholder="Transcription will appear here...",
        )
        trigger = gr.Button(button_label)
        trigger.click(fn=transcribe_audio, inputs=audio_input, outputs=transcript_box)
    return audio_input, transcript_box, trigger


with gr.Blocks() as demo:
    gr.Markdown("## Audio Transcription App")
    with gr.Tabs():
        # Tab 1: transcribe an existing audio file picked from disk.
        audio_file, output_text1, upload_button = _build_transcription_tab(
            "Upload Audio",
            "upload",
            "Upload Audio File",
            "Transcribe Uploaded Audio",
        )
        # Tab 2: transcribe audio captured from the microphone.
        audio_mic, output_text2, record_button = _build_transcription_tab(
            "Record Audio",
            "microphone",
            "Record Audio",
            "Transcribe Recorded Audio",
        )
    gr.Markdown("""
### Instructions:
1. Choose either 'Upload Audio' or 'Record Audio' tab
2. Upload an audio file or record using your microphone
3. Click the respective 'Transcribe' button
4. Wait for the transcription to appear
""")
def main():
    """Start the Gradio server for the transcription demo."""
    demo.launch()


if __name__ == "__main__":
    main()