Add Guidance for Repetition Penalty
Browse fileshttps://github.com/huggingface/transformers/pull/37625 added support for excluding the input tokens from RepetitionPenaltyLogitsProcessor - this updates the code snippet to do this with a repetition penalty of 3.
README.md
CHANGED
|
@@ -51,7 +51,7 @@ Then run the code:
|
|
| 51 |
```python
|
| 52 |
import torch
|
| 53 |
import torchaudio
|
| 54 |
-
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
| 55 |
from huggingface_hub import hf_hub_download
|
| 56 |
|
| 57 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -64,7 +64,6 @@ speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
|
|
| 64 |
model_name).to(device)
|
| 65 |
|
| 66 |
# prepare speech and text prompt, using the appropriate prompt template
|
| 67 |
-
|
| 68 |
audio_path = hf_hub_download(repo_id=model_name, filename='10226_10111_000000.wav')
|
| 69 |
wav, sr = torchaudio.load(audio_path, normalize=True)
|
| 70 |
assert wav.shape[0] == 1 and sr == 16000 # mono, 16khz
|
|
@@ -92,7 +91,14 @@ model_inputs = speech_granite_processor(
|
|
| 92 |
device=device, # Computation device; returned tensors are put on CPU
|
| 93 |
return_tensors="pt",
|
| 94 |
).to(device)
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
model_outputs = speech_granite.generate(
|
| 97 |
**model_inputs,
|
| 98 |
max_new_tokens=200,
|
|
@@ -100,9 +106,9 @@ model_outputs = speech_granite.generate(
|
|
| 100 |
do_sample=False,
|
| 101 |
min_length=1,
|
| 102 |
top_p=1.0,
|
| 103 |
-
repetition_penalty=1.0,
|
| 104 |
length_penalty=1.0,
|
| 105 |
temperature=1.0,
|
|
|
|
| 106 |
bos_token_id=tokenizer.bos_token_id,
|
| 107 |
eos_token_id=tokenizer.eos_token_id,
|
| 108 |
pad_token_id=tokenizer.pad_token_id,
|
|
|
|
| 51 |
```python
|
| 52 |
import torch
|
| 53 |
import torchaudio
|
| 54 |
+
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, RepetitionPenaltyLogitsProcessor
|
| 55 |
from huggingface_hub import hf_hub_download
|
| 56 |
|
| 57 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 64 |
model_name).to(device)
|
| 65 |
|
| 66 |
# prepare speech and text prompt, using the appropriate prompt template
|
|
|
|
| 67 |
audio_path = hf_hub_download(repo_id=model_name, filename='10226_10111_000000.wav')
|
| 68 |
wav, sr = torchaudio.load(audio_path, normalize=True)
|
| 69 |
assert wav.shape[0] == 1 and sr == 16000 # mono, 16khz
|
|
|
|
| 91 |
device=device, # Computation device; returned tensors are put on CPU
|
| 92 |
return_tensors="pt",
|
| 93 |
).to(device)
|
| 94 |
+
|
| 95 |
+
# The recommended repetition penalty is 3 as long as input IDs are excluded.
|
| 96 |
+
# Otherwise, you should use a reptition penalty of 1 to keep results stable.
|
| 97 |
+
reptition_penalty_processor = RepetitionPenaltyLogitsProcessor(
|
| 98 |
+
penalty=3.0,
|
| 99 |
+
prompt_ignore_length=model_inputs["input_ids"].shape[-1],
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
model_outputs = speech_granite.generate(
|
| 103 |
**model_inputs,
|
| 104 |
max_new_tokens=200,
|
|
|
|
| 106 |
do_sample=False,
|
| 107 |
min_length=1,
|
| 108 |
top_p=1.0,
|
|
|
|
| 109 |
length_penalty=1.0,
|
| 110 |
temperature=1.0,
|
| 111 |
+
logits_processor=[reptition_penalty_processor],
|
| 112 |
bos_token_id=tokenizer.bos_token_id,
|
| 113 |
eos_token_id=tokenizer.eos_token_id,
|
| 114 |
pad_token_id=tokenizer.pad_token_id,
|