JotunnBurton commited on
Commit
bc9f98e
·
verified ·
1 Parent(s): ae30800

Update clap_wrapper.py

Browse files
Files changed (1) hide show
  1. clap_wrapper.py +42 -18
clap_wrapper.py CHANGED
@@ -1,34 +1,53 @@
1
  import sys
 
 
2
  import torch
3
- from huggingface_hub import hf_hub_download
4
  from transformers import ClapModel, ClapProcessor
 
5
  from config import config
6
 
7
- models = dict()
8
-
9
- # กำหนดชื่อโมเดลและโฟลเดอร์ที่ต้องการเก็บ
10
- model_name = "laion/clap-htsat-fused"
11
  LOCAL_PATH = "./emotional/clap-htsat-fused"
12
 
13
- # ดาวน์โหลดโมเดลจาก Hugging Face
14
- hf_hub_download(repo_id=model_name,filename="pytorch_model.bin", cache_dir=LOCAL_PATH, force_download=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Now load the processor and model from the local directory
17
- processor = ClapProcessor.from_pretrained(LOCAL_PATH)
18
-
19
  def get_clap_audio_feature(audio_data, device=config.bert_gen_config.device):
20
- if sys.platform == "darwin" and torch.backends.mps.is_available() and device == "cpu":
 
 
 
 
21
  device = "mps"
22
  if not device:
23
  device = "cuda"
24
- if device not in models:
25
  if config.webui_config.fp16_run:
26
  models[device] = ClapModel.from_pretrained(
27
- LOCAL_PATH, torch_dtype=torch.float16
28
  ).to(device)
29
  else:
30
  models[device] = ClapModel.from_pretrained(
31
- LOCAL_PATH
32
  ).to(device)
33
  with torch.no_grad():
34
  inputs = processor(
@@ -37,19 +56,24 @@ def get_clap_audio_feature(audio_data, device=config.bert_gen_config.device):
37
  emb = models[device].get_audio_features(**inputs).float()
38
  return emb.T
39
 
 
40
  def get_clap_text_feature(text, device=config.bert_gen_config.device):
41
- if sys.platform == "darwin" and torch.backends.mps.is_available() and device == "cpu":
 
 
 
 
42
  device = "mps"
43
  if not device:
44
  device = "cuda"
45
- if device not in models:
46
  if config.webui_config.fp16_run:
47
  models[device] = ClapModel.from_pretrained(
48
- LOCAL_PATH, torch_dtype=torch.float16
49
  ).to(device)
50
  else:
51
  models[device] = ClapModel.from_pretrained(
52
- LOCAL_PATH
53
  ).to(device)
54
  with torch.no_grad():
55
  inputs = processor(text=text, return_tensors="pt").to(device)
 
1
  import sys
2
+ import os
3
+
4
  import torch
 
5
  from transformers import ClapModel, ClapProcessor
6
+ from huggingface_hub import hf_hub_download
7
  from config import config
8
 
9
+ # กำหนดชื่อและ path ของโมเดล
10
+ HF_REPO_ID = "laion/clap-htsat-fused"
 
 
11
  LOCAL_PATH = "./emotional/clap-htsat-fused"
12
 
13
+ # ตรวจสอบว่ามีไฟล์โมเดลใน LOCAL_PATH แล้วหรือยัง ถ้าไม่มีก็ดาวน์โหลด
14
+ def ensure_model_downloaded():
15
+ os.makedirs(LOCAL_PATH, exist_ok=True)
16
+ required_files = ["pytorch_model.bin", "config.json", "preprocessor_config.json"]
17
+ for file in required_files:
18
+ local_file_path = os.path.join(LOCAL_PATH, file)
19
+ if not os.path.isfile(local_file_path):
20
+ print(f"Downloading {file} from {HF_REPO_ID}...")
21
+ hf_hub_download(
22
+ repo_id=HF_REPO_ID,
23
+ filename=file,
24
+ cache_dir=LOCAL_PATH,
25
+ force_download=False
26
+ )
27
+
28
+ ensure_model_downloaded()
29
+
30
+ # โหลด processor
31
+ models = dict()
32
+ processor = ClapProcessor.from_pretrained(LOCAL_PATH, local_files_only=True)
33
 
 
 
 
34
  def get_clap_audio_feature(audio_data, device=config.bert_gen_config.device):
35
+ if (
36
+ sys.platform == "darwin"
37
+ and torch.backends.mps.is_available()
38
+ and device == "cpu"
39
+ ):
40
  device = "mps"
41
  if not device:
42
  device = "cuda"
43
+ if device not in models.keys():
44
  if config.webui_config.fp16_run:
45
  models[device] = ClapModel.from_pretrained(
46
+ LOCAL_PATH, torch_dtype=torch.float16, local_files_only=True
47
  ).to(device)
48
  else:
49
  models[device] = ClapModel.from_pretrained(
50
+ LOCAL_PATH, local_files_only=True
51
  ).to(device)
52
  with torch.no_grad():
53
  inputs = processor(
 
56
  emb = models[device].get_audio_features(**inputs).float()
57
  return emb.T
58
 
59
+
60
  def get_clap_text_feature(text, device=config.bert_gen_config.device):
61
+ if (
62
+ sys.platform == "darwin"
63
+ and torch.backends.mps.is_available()
64
+ and device == "cpu"
65
+ ):
66
  device = "mps"
67
  if not device:
68
  device = "cuda"
69
+ if device not in models.keys():
70
  if config.webui_config.fp16_run:
71
  models[device] = ClapModel.from_pretrained(
72
+ LOCAL_PATH, torch_dtype=torch.float16, local_files_only=True
73
  ).to(device)
74
  else:
75
  models[device] = ClapModel.from_pretrained(
76
+ LOCAL_PATH, local_files_only=True
77
  ).to(device)
78
  with torch.no_grad():
79
  inputs = processor(text=text, return_tensors="pt").to(device)