Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Commit 
							
							·
						
						a70eba7
	
1
								Parent(s):
							
							c4fe16f
								
fix
Browse files- tools/download_files.py +9 -71
- webui.py +3 -15
    	
        tools/download_files.py
    CHANGED
    
    | @@ -3,30 +3,6 @@ import zipfile | |
| 3 | 
             
            import os
         | 
| 4 | 
             
            import argparse
         | 
| 5 |  | 
| 6 | 
            -
            def download_file_from_google_drive(file_id, destination):
         | 
| 7 | 
            -
                """
         | 
| 8 | 
            -
                通过文件ID下载Google Drive共享文件
         | 
| 9 | 
            -
             | 
| 10 | 
            -
                Args:
         | 
| 11 | 
            -
                    file_id (str): Google Drive文件的ID
         | 
| 12 | 
            -
                    destination (str): 本地保存路径
         | 
| 13 | 
            -
                """
         | 
| 14 | 
            -
                # 基本的下载URL
         | 
| 15 | 
            -
                URL = "https://docs.google.com/uc?export=download"
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                session = requests.Session()
         | 
| 18 | 
            -
             | 
| 19 | 
            -
                # 发起初始GET请求
         | 
| 20 | 
            -
                response = session.get(URL, params={'id': file_id}, stream=True)
         | 
| 21 | 
            -
                token = get_confirm_token(response)  # 从响应中获取确认令牌(如果需要)
         | 
| 22 | 
            -
             | 
| 23 | 
            -
                if token: # 如果需要确认(大文件)
         | 
| 24 | 
            -
                    params = {'id': file_id, 'confirm': token}
         | 
| 25 | 
            -
                    response = session.get(URL, params=params, stream=True)
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                # 将响应内容保存到文件
         | 
| 28 | 
            -
                save_response_content(response, destination)
         | 
| 29 | 
            -
             | 
| 30 | 
             
            def get_confirm_token(response):
         | 
| 31 | 
             
                """
         | 
| 32 | 
             
                从响应中检查是否存在下载确认令牌(cookie)
         | 
| @@ -57,54 +33,27 @@ def save_response_content(response, destination, chunk_size=32768): | |
| 57 | 
             
                            f.write(chunk)
         | 
| 58 |  | 
| 59 | 
             
            def download_model_from_modelscope(destination,hf_cache_dir):
         | 
| 60 | 
            -
                """
         | 
| 61 | 
            -
                从ModelScope下载模型(伪代码,需根据实际API实现)
         | 
| 62 | 
            -
                Args:
         | 
| 63 | 
            -
                    model_id (str): ModelScope模型ID
         | 
| 64 | 
            -
                    destination (str): 本地保存路径
         | 
| 65 | 
            -
                """
         | 
| 66 | 
             
                print(f"[ModelScope] Downloading models to {destination},model cache dir={hf_cache_dir}")
         | 
| 67 | 
             
                from modelscope import snapshot_download
         | 
| 68 | 
            -
                 | 
| 69 | 
            -
                 | 
| 70 | 
            -
                 | 
| 71 | 
            -
                 | 
| 72 | 
            -
                snapshot_download(" | 
| 73 | 
            -
                snapshot_download("amphion/MaskGCT", local_dir="checkpoints/hf_cache/models--amphion--MaskGCT")
         | 
| 74 | 
            -
                snapshot_download("facebook/w2v-bert-2.0",local_dir="checkpoints/hf_cache/models--facebook--w2v-bert-2.0")
         | 
| 75 | 
            -
                snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x",local_dir="checkpoints/hf_cache/models--nvidia--bigvgan_v2_22khz_80band_256x")
         | 
| 76 | 
            -
                # models--funasr--campplus
         | 
| 77 | 
            -
                snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x",local_dir="checkpoints/hf_cache/models--nvidia--bigvgan_v2_22khz_80band_256x")
         | 
| 78 |  | 
| 79 | 
             
            def download_model_from_huggingface(destination,hf_cache_dir):
         | 
| 80 | 
            -
                """
         | 
| 81 | 
            -
                从HuggingFace下载模型(伪代码,需根据实际API实现)
         | 
| 82 | 
            -
                Args:
         | 
| 83 | 
            -
                    model_id (str): HuggingFace模型ID
         | 
| 84 | 
            -
                    destination (str): 本地保存路径
         | 
| 85 | 
            -
                """
         | 
| 86 | 
             
                print(f"[HuggingFace] Downloading models to {destination},model cache dir={hf_cache_dir}")
         | 
| 87 | 
             
                from huggingface_hub import snapshot_download
         | 
| 88 | 
            -
                os.makedirs(os.path.join(hf_cache_dir,"models--amphion--MaskGCT"), exist_ok=True)
         | 
| 89 | 
            -
                os.makedirs(os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"), exist_ok=True)
         | 
| 90 | 
            -
                os.makedirs(os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"), exist_ok=True)
         | 
| 91 | 
            -
                os.makedirs(os.path.join(hf_cache_dir,"models--funasr--campplus"), exist_ok=True)
         | 
| 92 | 
             
                snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
         | 
| 93 | 
            -
                 | 
| 94 | 
            -
                 | 
| 95 | 
            -
                # print("[HuggingFace] MaskGCT Download finished")
         | 
| 96 | 
            -
                # snapshot_download("facebook/w2v-bert-2.0",local_dir=os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"))
         | 
| 97 | 
            -
                snapshot_download("facebook/w2v-bert-2.0")
         | 
| 98 | 
            -
                print("[HuggingFace] w2v-bert-2.0 Download finished")
         | 
| 99 | 
             
                snapshot_download("nvidia/bigvgan_v2_22khz_80band_256x",local_dir=os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"))
         | 
| 100 | 
            -
                print("[HuggingFace] bigvgan_v2_22khz_80band_256x Download finished")
         | 
| 101 | 
             
                snapshot_download("funasr/campplus",local_dir=os.path.join(hf_cache_dir,"models--funasr--campplus"))
         | 
| 102 | 
            -
                print("[HuggingFace] campplus Download finished")
         | 
| 103 |  | 
| 104 | 
             
            # 使用示例
         | 
| 105 | 
             
            if __name__ == "__main__":
         | 
| 106 | 
            -
                parser = argparse.ArgumentParser(description=" | 
| 107 | 
            -
                parser.add_argument('--model_source', choices=['modelscope', 'huggingface'], default=None, help=' | 
| 108 | 
             
                args = parser.parse_args()
         | 
| 109 |  | 
| 110 | 
             
                if args.model_source:
         | 
| @@ -112,14 +61,3 @@ if __name__ == "__main__": | |
| 112 | 
             
                        download_model_from_modelscope("checkpoints",os.path.join("checkpoints","hf_cache"))
         | 
| 113 | 
             
                    elif args.model_source == 'huggingface':
         | 
| 114 | 
             
                        download_model_from_huggingface("checkpoints",os.path.join("checkpoints","hf_cache"))
         | 
| 115 | 
            -
             | 
| 116 | 
            -
                print("Downloading example files from Google Drive...")
         | 
| 117 | 
            -
                file_id = "1o_dCMzwjaA2azbGOxAE7-4E7NbJkgdgO"
         | 
| 118 | 
            -
                destination = "example_wavs.zip" # 替换为你希望的本地路径
         | 
| 119 | 
            -
                download_file_from_google_drive(file_id, destination)
         | 
| 120 | 
            -
                print(f"File downloaded to: {destination}")
         | 
| 121 | 
            -
                # 解压下载的zip文件到examples目录
         | 
| 122 | 
            -
                examples_dir = "examples"
         | 
| 123 | 
            -
                with zipfile.ZipFile(destination, 'r') as zip_ref:
         | 
| 124 | 
            -
                    zip_ref.extractall(examples_dir)
         | 
| 125 | 
            -
                print(f"File extracted to: {examples_dir}")
         | 
|  | |
| 3 | 
             
            import os
         | 
| 4 | 
             
            import argparse
         | 
| 5 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 6 | 
             
            def get_confirm_token(response):
         | 
| 7 | 
             
                """
         | 
| 8 | 
             
                从响应中检查是否存在下载确认令牌(cookie)
         | 
|  | |
| 33 | 
             
                            f.write(chunk)
         | 
| 34 |  | 
| 35 | 
             
            def download_model_from_modelscope(destination,hf_cache_dir):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 36 | 
             
                print(f"[ModelScope] Downloading models to {destination},model cache dir={hf_cache_dir}")
         | 
| 37 | 
             
                from modelscope import snapshot_download
         | 
| 38 | 
            +
                snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
         | 
| 39 | 
            +
                snapshot_download("amphion/MaskGCT", local_dir=os.path.join(hf_cache_dir,"models--amphion--MaskGCT"))
         | 
| 40 | 
            +
                snapshot_download("facebook/w2v-bert-2.0",local_dir=os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"))
         | 
| 41 | 
            +
                snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x",local_dir=os.path.join(hf_cache_dir,"models--nvidia--bigvgan_v2_22khz_80band_256x"))
         | 
| 42 | 
            +
                snapshot_download("iic/speech_campplus_sv_zh-cn_16k-common",local_dir=os.path.join(hf_cache_dir,"models--funasr--campplus"))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 43 |  | 
| 44 | 
             
            def download_model_from_huggingface(destination,hf_cache_dir):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 45 | 
             
                print(f"[HuggingFace] Downloading models to {destination},model cache dir={hf_cache_dir}")
         | 
| 46 | 
             
                from huggingface_hub import snapshot_download
         | 
|  | |
|  | |
|  | |
|  | |
| 47 | 
             
                snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
         | 
| 48 | 
            +
                snapshot_download("amphion/MaskGCT", local_dir=os.path.join(hf_cache_dir,"models--amphion--MaskGCT"))
         | 
| 49 | 
            +
                snapshot_download("facebook/w2v-bert-2.0",local_dir=os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"))
         | 
|  | |
|  | |
|  | |
|  | |
| 50 | 
             
                snapshot_download("nvidia/bigvgan_v2_22khz_80band_256x",local_dir=os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"))
         | 
|  | |
| 51 | 
             
                snapshot_download("funasr/campplus",local_dir=os.path.join(hf_cache_dir,"models--funasr--campplus"))
         | 
|  | |
| 52 |  | 
| 53 | 
             
            # 使用示例
         | 
| 54 | 
             
            if __name__ == "__main__":
         | 
| 55 | 
            +
                parser = argparse.ArgumentParser(description="Download models and example files")
         | 
| 56 | 
            +
                parser.add_argument('-s','--model_source', choices=['modelscope', 'huggingface'], default=None, help='Model source')
         | 
| 57 | 
             
                args = parser.parse_args()
         | 
| 58 |  | 
| 59 | 
             
                if args.model_source:
         | 
|  | |
| 61 | 
             
                        download_model_from_modelscope("checkpoints",os.path.join("checkpoints","hf_cache"))
         | 
| 62 | 
             
                    elif args.model_source == 'huggingface':
         | 
| 63 | 
             
                        download_model_from_huggingface("checkpoints",os.path.join("checkpoints","hf_cache"))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        webui.py
    CHANGED
    
    | @@ -32,21 +32,9 @@ parser.add_argument("--cuda_kernel", action="store_true", default=False, help="U | |
| 32 | 
             
            parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
         | 
| 33 | 
             
            cmd_args = parser.parse_args()
         | 
| 34 |  | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
            for file in [
         | 
| 40 | 
            -
                "bpe.model",
         | 
| 41 | 
            -
                "gpt.pth",
         | 
| 42 | 
            -
                "config.yaml",
         | 
| 43 | 
            -
                "s2mel.pth",
         | 
| 44 | 
            -
                "wav2vec2bert_stats.pt"
         | 
| 45 | 
            -
            ]:
         | 
| 46 | 
            -
                file_path = os.path.join(cmd_args.model_dir, file)
         | 
| 47 | 
            -
                if not os.path.exists(file_path):
         | 
| 48 | 
            -
                    print(f"Required file {file_path} does not exist. Please download it.")
         | 
| 49 | 
            -
                    sys.exit(1)
         | 
| 50 |  | 
| 51 | 
             
            import gradio as gr
         | 
| 52 | 
             
            from indextts.infer_v2 import IndexTTS2
         | 
|  | |
| 32 | 
             
            parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
         | 
| 33 | 
             
            cmd_args = parser.parse_args()
         | 
| 34 |  | 
| 35 | 
            +
            from tools.download_files import download_model_from_huggingface
         | 
| 36 | 
            +
            download_model_from_huggingface(os.path.join(current_dir,"checkpoints"),
         | 
| 37 | 
            +
                                            os.path.join(current_dir, "checkpoints","hf_cache"))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 38 |  | 
| 39 | 
             
            import gradio as gr
         | 
| 40 | 
             
            from indextts.infer_v2 import IndexTTS2
         | 
