"""Gradio demo: speech-to-text with IBM's granite-speech-3.3-8b model."""

from datetime import datetime

import gradio as gr
import spaces
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
model_name = "ibm-granite/granite-speech-3.3-8b" |
|
|
|
|
|
processor = AutoProcessor.from_pretrained(model_name) |
|
|
tokenizer = processor.tokenizer |
|
|
model = AutoModelForSpeechSeq2Seq.from_pretrained( |
|
|
model_name, device_map=device, torch_dtype=torch.bfloat16 |
|
|
) |
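
# Note: bfloat16 keeps the 8B checkpoint at roughly half the memory footprint
# of float32; on a CPU-only machine the model will still be slow to load and run.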


def _load_audio_mono_16k(file_path: str) -> torch.Tensor:
    """Load an audio file, downmix it to mono, and resample it to 16 kHz."""
    wav, sr = torchaudio.load(file_path, normalize=True)
    if wav.shape[0] > 1:
        # Downmix multi-channel audio by averaging across channels.
        wav = torch.mean(wav, dim=0, keepdim=True)
    if sr != 16000:
        # The Granite speech models expect 16 kHz input.
        wav = torchaudio.functional.resample(wav, sr, 16000)
    return wav
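
# Example (hypothetical path): _load_audio_mono_16k("sample.wav") returns a
# (1, num_samples) float tensor at 16 kHz regardless of the input format.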


@spaces.GPU
def process_audio(audio_path: str, instruction: str) -> str:
    if not audio_path:
        return "Please upload an audio file."

    wav = _load_audio_mono_16k(audio_path)

    date_string = datetime.now().strftime("%B %d, %Y")

    system_prompt = (
        "Knowledge Cutoff Date: April 2024.\n"
        f"Today's Date: {date_string}.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
    )
    # The <|audio|> token marks where the audio features are spliced into the prompt.
    user_prompt = f"<|audio|>{instruction.strip()}"
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # The processor tokenizes the prompt and extracts audio features in one call.
    model_inputs = processor(prompt, wav, device=device, return_tensors="pt").to(device)
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=4096,
        do_sample=False,  # greedy decoding for deterministic transcripts
        num_beams=1,
    )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = torch.unsqueeze(outputs[0, num_input_tokens:], dim=0)
    text = tokenizer.batch_decode(new_tokens, add_special_tokens=False, skip_special_tokens=True)[0]
    return text


with gr.Blocks(title="Granite Speech Demo") as demo:
    gr.Markdown("# Granite Speech-to-Text Demo")
    gr.Markdown("Upload audio and transcribe with IBM Granite.")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            instruction = gr.Textbox(
                label="Instruction",
                value="can you transcribe the speech into a written format?",
            )
            submit_btn = gr.Button("Transcribe", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=12)

    submit_btn.click(process_audio, [audio_input, instruction], output_text)


if __name__ == "__main__":
    demo.queue().launch(share=False, ssr_mode=False)
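
# Running locally (a sketch; assumes this file is saved as app.py):
#   pip install gradio spaces torch torchaudio transformers
#   python app.py
# Outside a Hugging Face Space the @spaces.GPU decorator should be a no-op.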