Update export_models.sh
Browse files- export_models.sh +13 -13
    	
        export_models.sh
    CHANGED
    
    | @@ -6,22 +6,22 @@ python << END | |
| 6 | 
             
            from transformers import WhisperForConditionalGeneration, TFWhisperForConditionalGeneration, WhisperTokenizerFast
         | 
| 7 | 
             
            import shutil
         | 
| 8 |  | 
| 9 | 
            -
            # Backup generation_config.json
         | 
| 10 | 
            -
            shutil.copyfile('./generation_config.json', './generation_config_backup.json')
         | 
| 11 |  | 
| 12 | 
             
            print("Saving model to PyTorch...", end=" ")
         | 
| 13 | 
             
            model = WhisperForConditionalGeneration.from_pretrained("./", from_flax=True)
         | 
| 14 | 
             
            model.save_pretrained("./", safe_serialization=True)
         | 
| 15 | 
            -
            model.save_pretrained("./", safe_serialization=False)
         | 
| 16 | 
             
            print("Done.")
         | 
| 17 |  | 
| 18 | 
            -
            print("Saving model to TensorFlow...", end=" ")
         | 
| 19 | 
            -
            tf_model = TFWhisperForConditionalGeneration.from_pretrained("./")
         | 
| 20 | 
            -
            tf_model.save_pretrained("./")
         | 
| 21 | 
            -
            print("Done.")
         | 
| 22 |  | 
| 23 | 
             
            # Restore the backup of generation_config.json
         | 
| 24 | 
            -
            shutil.move('./generation_config_backup.json', './generation_config.json')
         | 
| 25 |  | 
| 26 | 
             
            print("Saving model to ONNX...", end=" ")
         | 
| 27 | 
             
            from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
         | 
| @@ -29,16 +29,16 @@ ort_model = ORTModelForSpeechSeq2Seq.from_pretrained("./", export=True) | |
| 29 | 
             
            ort_model.save_pretrained("./onnx")
         | 
| 30 | 
             
            print("Done")
         | 
| 31 |  | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
|  | |
|  | |
| 34 | 
             
            cp ct2/model.bin .
         | 
| 35 | 
             
            cp ct2/vocabulary.json .
         | 
| 36 | 
             
            cp config.json config_hf.json
         | 
| 37 | 
             
            jq -s '.[0] * .[1]' ct2/config.json config_hf.json > config.json
         | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 |  | 
| 41 | 
            -
            END
         | 
| 42 |  | 
| 43 | 
             
            echo "Saving model to GGML (whisper.cpp)..."
         | 
| 44 | 
             
            wget -O convert-h5-to-ggml.py "https://raw.githubusercontent.com/NbAiLab/nb-whisper/main/convert-h5-to-ggml.py"
         | 
|  | |
| 6 | 
             
            from transformers import WhisperForConditionalGeneration, TFWhisperForConditionalGeneration, WhisperTokenizerFast
         | 
| 7 | 
             
            import shutil
         | 
| 8 |  | 
| 9 | 
            +
            # Backup generation_config.json - this is for tensorflow only, but at the moment that is causing errors.
         | 
| 10 | 
            +
            # shutil.copyfile('./generation_config.json', './generation_config_backup.json')
         | 
| 11 |  | 
| 12 | 
             
            print("Saving model to PyTorch...", end=" ")
         | 
| 13 | 
             
            model = WhisperForConditionalGeneration.from_pretrained("./", from_flax=True)
         | 
| 14 | 
             
            model.save_pretrained("./", safe_serialization=True)
         | 
| 15 | 
            +
            model.save_pretrained("./", safe_serialization=False, max_shard_size="10000MB")
         | 
| 16 | 
             
            print("Done.")
         | 
| 17 |  | 
| 18 | 
            +
            #print("Saving model to TensorFlow...", end=" ")
         | 
| 19 | 
            +
            #tf_model = TFWhisperForConditionalGeneration.from_pretrained("./")
         | 
| 20 | 
            +
            #tf_model.save_pretrained("./")
         | 
| 21 | 
            +
            #print("Done.")
         | 
| 22 |  | 
| 23 | 
             
            # Restore the backup of generation_config.json
         | 
| 24 | 
            +
            #shutil.move('./generation_config_backup.json', './generation_config.json')
         | 
| 25 |  | 
| 26 | 
             
            print("Saving model to ONNX...", end=" ")
         | 
| 27 | 
             
            from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
         | 
|  | |
| 29 | 
             
            ort_model.save_pretrained("./onnx")
         | 
| 30 | 
             
            print("Done")
         | 
| 31 |  | 
| 32 | 
            +
            END
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            echo "Saving model to CTranslate..."
         | 
| 35 | 
            +
            ct2-transformers-converter --model . --output_dir ct2 --force
         | 
| 36 | 
             
            cp ct2/model.bin .
         | 
| 37 | 
             
            cp ct2/vocabulary.json .
         | 
| 38 | 
             
            cp config.json config_hf.json
         | 
| 39 | 
             
            jq -s '.[0] * .[1]' ct2/config.json config_hf.json > config.json
         | 
| 40 | 
            +
            echo "Done"
         | 
|  | |
| 41 |  | 
|  | |
| 42 |  | 
| 43 | 
             
            echo "Saving model to GGML (whisper.cpp)..."
         | 
| 44 | 
             
            wget -O convert-h5-to-ggml.py "https://raw.githubusercontent.com/NbAiLab/nb-whisper/main/convert-h5-to-ggml.py"
         | 

