Spaces: Running on Zero

PRamoneda committed · df703c7
Commit message: cpu
1 Parent(s): 2e9908b

Files changed:
- get_difficulty.py (+6 -6)
- model.py (+5 -5)
get_difficulty.py CHANGED

@@ -31,14 +31,14 @@ def get_cqt_from_mp3(mp3_path):
     log_cqt = librosa.amplitude_to_db(np.abs(cqt))
     log_cqt = log_cqt.T  # shape (T, 88)
     log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
-    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cuda()
+    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
     # pdb.set_trace()
     print(f"cqt shape: {log_cqt.shape}")
     return cqt_tensor

 def get_pianoroll_from_mp3(mp3_path):
     audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
-    transcriptor = PianoTranscription(device='cuda')
+    transcriptor = PianoTranscription(device='cpu')
     midi_path = "temp.mid"
     transcriptor.transcribe(audio, midi_path)
     midi_data = pretty_midi.PrettyMIDI(midi_path)

@@ -57,8 +57,8 @@ def get_pianoroll_from_mp3(mp3_path):
         if 0 <= pitch < 88 and onset_frame < time_steps:
             onsets[onset_frame, pitch] = 1.0

-    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cuda().float()
-    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cuda().float()
+    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cpu().float()
+    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cpu().float()
     out_tensor = torch.cat([pr_tensor, on_tensor], dim=1)
     print(f"piano_roll shape: {out_tensor.shape}")
     return out_tensor.transpose(2, 3)

@@ -75,7 +75,7 @@ def predict_difficulty(mp3_path, model_name, rep):
     rep_clean = rep

     model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr)
-    checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location="cuda")
+    checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location="cpu", weights_only=False)
                   for i in range(5)]

@@ -93,7 +93,7 @@ def predict_difficulty(mp3_path, model_name, rep):
     preds = []
     for cheks in checkpoint:
         model.load_state_dict(cheks["model_state_dict"])
-        model = model.cuda().eval()
+        model = model.cpu().eval()
         with torch.inference_mode():
             logits = model(inp_data, None)
             pred = prediction2label(logits).item()
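The get_difficulty.py changes pin every input tensor, the piano transcriptor, and the model itself to the CPU, which is the safe default for a Space that cannot assume a CUDA device. A minimal device-agnostic sketch of the same CQT preprocessing is shown below; DEVICE and prepare_cqt are illustrative names, not part of the repository, and the tensor shapes follow the code above.

import torch

# Pick the device once; this falls back to CPU when no GPU is visible,
# which is the behaviour the commit hard-codes.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def prepare_cqt(log_cqt):
    # Turn a (T, 88) log-CQT array into the (1, 1, T, 88) float tensor
    # the model's forward pass expects, on whichever device is active.
    return torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)

The hard-coded .cpu() calls in the diff are equivalent to this sketch with DEVICE fixed to "cpu".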
model.py CHANGED

@@ -222,7 +222,7 @@ def get_mse_macro(y_true, y_pred):

 def get_cqt(rep, k):
     inp_data = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
-    inp_data = torch.tensor(inp_data, dtype=torch.float32).cuda()
+    inp_data = torch.tensor(inp_data, dtype=torch.float32).cpu()
     inp_data = inp_data.unsqueeze(0).unsqueeze(0).transpose(2, 3)
     return inp_data

@@ -230,8 +230,8 @@ def get_cqt(rep, k):
 def get_pianoroll(rep, k):
     inp_pr = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
     inp_on = utils.load_binary(f"../videos_download/{rep}/{k}_onset.bin")
-    inp_pr = torch.from_numpy(inp_pr).float().cuda()
-    inp_on = torch.from_numpy(inp_on).float().cuda()
+    inp_pr = torch.from_numpy(inp_pr).float().cpu()
+    inp_on = torch.from_numpy(inp_on).float().cpu()
     inp_data = torch.stack([inp_pr, inp_on], dim=1)
     inp_data = inp_data.unsqueeze(0).permute(0, 1, 2, 3)
     return inp_data

@@ -255,12 +255,12 @@ def compute_model_basic(model_name, rep, modality_dropout, only_cqt=False, only_pr=False):
     for split in range(5):
         #load_model
         model = AudioModel(11, rep, modality_dropout, only_cqt, only_pr)
-        checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cuda')
+        checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cpu')
         # print(checkpoint["epoch"])
         # print(checkpoint.keys())

         model.load_state_dict(checkpoint['model_state_dict'])
-        model = model.cuda()
+        model = model.cpu()
         pred_labels, true_labels = [], []
         predictions_split = {}
         model.eval()
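Both files load their five per-fold checkpoints with map_location='cpu', which remaps CUDA-saved storages so the weights deserialize on a machine without a GPU. A hedged sketch of that loading pattern is shown below; load_fold_models is an illustrative helper, while the checkpoint path template, the "model_state_dict" key, and the weights_only=False argument come from the diff above.

import torch

def load_fold_models(model, model_name, n_folds=5):
    # Re-load the same model object with each fold's weights, kept on the CPU.
    for i in range(n_folds):
        # map_location remaps GPU storages to CPU; weights_only=False keeps
        # the full pickled checkpoint dict available, not just the tensors.
        ckpt = torch.load(f"models/{model_name}/checkpoint_{i}.pth",
                          map_location="cpu", weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        yield model.cpu().eval()

Each yielded model can then be run under torch.inference_mode(), as predict_difficulty does for its five checkpoints.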