import os

import gradio as gr
import spaces
import torch
import torchaudio
from huggingface_hub import login
from transformers import AutoModelForCTC, AutoProcessor

# log in to the Hugging Face Hub if a token is provided
# (only needed for gated or private models; harmless to skip otherwise)
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# load the CTC model and its processor
MODEL_PATH = "badrex/w2v-bert-2.0-swahili-asr"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForCTC.from_pretrained(MODEL_PATH)

# move the model to the GPU when one is available; the processor is a
# stateless pre/post-processing object and stays on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
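
# On ZeroGPU Spaces, the @spaces.GPU() decorator below requests a GPU for
# the duration of each call and is effectively a no-op elsewhere, so the
# function also runs on CPU-only hardware.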
@spaces.GPU()
def transcribe(audio_path):
    """Transcribe the audio file at ``audio_path`` and return the text.

    Args:
        audio_path: Path to the audio file to be transcribed.

    Returns:
        String containing the transcribed text from the audio file,
        or an error message if the audio file is missing.
    """
    if not audio_path:
        return "Please upload an audio file."

    # load the audio as a (channels, samples) float tensor
    audio_array, sample_rate = torchaudio.load(audio_path)

    # the model expects 16 kHz input; resample if necessary
    if sample_rate != 16000:
        audio_array = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(audio_array)

    # collapse multi-channel recordings to mono and convert to a 1-D
    # numpy array, which is what the feature extractor expects
    audio_array = audio_array.mean(dim=0).numpy()

    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # greedy CTC decoding: take the most likely token at each frame, then
    # let the tokenizer collapse repeats and strip blank/special tokens
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    decoded_outputs = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True,
    )
    return decoded_outputs[0].strip()
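
# Local sanity check (hypothetical path; assumes a short Swahili clip has
# been saved under examples/):
#
#     print(transcribe("examples/sample.wav"))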

# collect bundled example clips, if any, for the demo UI
examples = []
examples_dir = "examples"
if os.path.exists(examples_dir):
    for filename in os.listdir(examples_dir):
        if filename.endswith((".wav", ".mp3", ".ogg")):
            examples.append([os.path.join(examples_dir, filename)])
    print(f"Found {len(examples)} example files")
else:
    print("Examples directory not found")

demo = gr.Interface(
    fn=transcribe,
    # type="filepath" hands the function a path on disk, which is what
    # torchaudio.load expects (the default type returns a numpy tuple)
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="<div>Swahili-ASR 🎙️ <br>Speech Recognition for Swahili</div>",
    description="""
    <div class="centered-content">
        <div>
            <p>
                Developed with ❤ by <a href="https://badrex.github.io/" style="color: #2563eb;">Badr al-Absi</a> ☕
            </p>
            <br>
            <p style="font-size: 15px; line-height: 1.8;">
                Hi there 👋🏼
                <br>
                <br>
                This is a demo for <a href="https://huggingface.co/badrex/w2v-bert-2.0-swahili-asr" style="color: #2563eb;">badrex/w2v-bert-2.0-swahili-asr</a>, a robust Transformer-based automatic speech recognition (ASR) system for the Swahili language.
                The underlying ASR model was trained on more than 400 hours of transcribed speech.
            </p>
            <p style="font-size: 15px; line-height: 1.8;">
                Simply <strong>upload an audio file</strong> 📤 or <strong>record yourself speaking</strong> 🎙️⏺️ to try out the model!
            </p>
        </div>
    </div>
    """,
    examples=examples if examples else None,
    cache_examples=False,
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()