Update README.md
Browse files
README.md
CHANGED
|
@@ -64,3 +64,38 @@ manually evaluate WER on test set - vietnamese part:
|
|
| 64 |
| this LoRA | 14.7% | 14.7% | 9.4% |
|
| 65 |
|
| 66 |
all training + evaluation scripts are on my repo: https://github.com/phineas-pta/fine-tune-whisper-vi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
| this LoRA | 14.7% | 14.7% | 9.4% |
|
| 65 |
|
| 66 |
all training + evaluation scripts are on my repo: https://github.com/phineas-pta/fine-tune-whisper-vi
|
| 67 |
+
|
| 68 |
+
usage example:
|
| 69 |
+
```python
|
| 70 |
+
# pip install peft accelerate bitsandbytes
|
| 71 |
+
import torch
|
| 72 |
+
import torchaudio
|
| 73 |
+
from peft import PeftModel, PeftConfig
|
| 74 |
+
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer
|
| 75 |
+
|
| 76 |
+
PEFT_MODEL_ID = "doof-ferb/whisper-large-peft-lora-vi"
|
| 77 |
+
BASE_MODEL_ID = PeftConfig.from_pretrained(PEFT_MODEL_ID).base_model_name_or_path
|
| 78 |
+
|
| 79 |
+
FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(BASE_MODEL_ID)
|
| 80 |
+
TOKENIZER = WhisperTokenizer.from_pretrained(BASE_MODEL_ID)
|
| 81 |
+
|
| 82 |
+
MODEL = PeftModel.from_pretrained(
|
| 83 |
+
WhisperForConditionalGeneration.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16).to("cuda:0"),
|
| 84 |
+
PEFT_MODEL_ID
|
| 85 |
+
).merge_and_unload(progressbar=True)
|
| 86 |
+
|
| 87 |
+
DECODER_ID = torch.tensor(
|
| 88 |
+
TOKENIZER.convert_tokens_to_ids(["<|startoftranscript|>", "<|vi|>", "<|transcribe|>", "<|notimestamps|>"]),
|
| 89 |
+
device=MODEL.device
|
| 90 |
+
).unsqueeze(dim=0)
|
| 91 |
+
|
| 92 |
+
waveform, sampling_rate = torchaudio.load("audio.mp3")
|
| 93 |
+
if waveform.size(0) > 1: # convert dual to mono channel
|
| 94 |
+
waveform = waveform.mean(dim=0, keepdim=True)
|
| 95 |
+
|
| 96 |
+
inputs = FEATURE_EXTRACTOR(waveform, sampling_rate=sampling_rate, return_tensors="pt").to(MODEL.device)
|
| 97 |
+
with torch.inference_mode(), torch.autocast(device_type="cuda"): # required by PEFT
|
| 98 |
+
predicted_ids = MODEL.generate(input_features=inputs.input_features, decoder_input_ids=DECODER_ID)
|
| 99 |
+
|
| 100 |
+
TOKENIZER.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
| 101 |
+
```
|