import argparse
import os
import shlex

from pyctcdecode import build_ctcdecoder
from transformers import AutoProcessor
from transformers import Wav2Vec2ProcessorWithLM
def _vocab_to_ctc_labels(vocab_dict):
    """Turn a ``token -> id`` vocabulary into a decoder label list indexed by id.

    The label at position ``i`` must be the token whose id is ``i``, because
    the CTC decoder maps logit column ``i`` to ``labels[i]``. Tokens are
    lower-cased as the original script did (presumably so labels match a
    lower-cased KenLM model -- confirm against the LM's training text).

    Args:
        vocab_dict: Mapping from token string to integer id.

    Returns:
        list[str]: Lower-cased tokens ordered by id.

    Raises:
        ValueError: If the ids are not exactly ``0..len(vocab_dict) - 1``, or if
            two distinct tokens become the same label after lower-casing.
            Either case would shift labels relative to the model's outputs.
    """
    ordered = sorted(vocab_dict.items(), key=lambda item: item[1])

    ids = [token_id for _, token_id in ordered]
    if ids != list(range(len(ids))):
        raise ValueError(
            "Tokenizer ids must be contiguous and start at 0 to line up with "
            f"the model's logits; got {len(ids)} tokens with ids from "
            f"{ids[0]} to {ids[-1]}"
        )

    labels = [token.lower() for token, _ in ordered]

    # A dict keyed on the lower-cased token would silently drop a colliding
    # entry and shorten the list; detect collisions explicitly instead.
    seen = set()
    collisions = set()
    for label in labels:
        if label in seen:
            collisions.add(label)
        seen.add(label)
    if collisions:
        raise ValueError(
            "Lower-casing the vocabulary merges distinct tokens into the same "
            f"decoder label(s): {sorted(collisions)}"
        )
    return labels


def main(args):
    """Attach a KenLM-backed pyctcdecode decoder to a Wav2Vec2 processor.

    Loads the processor from ``args.model_name_or_path``. Builds a CTC decoder
    over the tokenizer vocabulary and the KenLM model at
    ``args.kenlm_model_path``. Saves the resulting ``Wav2Vec2ProcessorWithLM``
    back to ``args.model_name_or_path``. Prints the shell command that
    converts the ARPA file in ``<model dir>/language_model`` to KenLM's
    binary format. That directory is assumed to be where ``save_pretrained``
    places the LM files -- confirm for your transformers version.

    Args:
        args: Namespace with ``model_name_or_path`` and ``kenlm_model_path``.

    Raises:
        ValueError: If the tokenizer vocabulary cannot be mapped one-to-one
            onto decoder labels.
    """
    processor = AutoProcessor.from_pretrained(args.model_name_or_path)
    labels = _vocab_to_ctc_labels(processor.tokenizer.get_vocab())

    decoder = build_ctcdecoder(
        labels=labels,
        kenlm_model_path=args.kenlm_model_path,
    )
    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
        decoder=decoder,
    )
    output_dir = args.model_name_or_path
    processor_with_lm.save_pretrained(output_dir)

    # Point the follow-up command at the directory actually saved to, not at
    # a hard-coded ``language_model/`` relative to the current directory.
    # The ``*.arpa`` glob stays outside the quoting so the shell expands it.
    lm_dir = shlex.quote(os.path.join(output_dir, "language_model"))
    print(
        f"Run: ~/bin/build_binary {lm_dir}/*.arpa {lm_dir}/5gram.bin -T $(pwd) "
        f"&& rm {lm_dir}/*.arpa"
    )
def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``. This matches
            the previous zero-argument behaviour.

    Returns:
        argparse.Namespace: Holds ``model_name_or_path`` (default ``"./"``)
        and the required ``kenlm_model_path``.
    """
    parser = argparse.ArgumentParser(
        description="Wrap a Wav2Vec2 processor with a KenLM-backed CTC decoder.",
    )
    parser.add_argument('--model_name_or_path', default="./", help='Model name or path. Defaults to ./')
    parser.add_argument('--kenlm_model_path', required=True, help='Path to KenLM arpa file.')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    main(args=parse_args())