# audio-difficulty / get_difficulty.py
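# Estimate the difficulty of a piano recording: extract a log-CQT and/or a
# transcribed piano-roll representation from an MP3 and feed it to an ensemble
# of AudioModel checkpoints.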
import os
from statistics import mean
import torch
from torch import nn
import numpy as np
import librosa
from piano_transcription_inference import PianoTranscription, sample_rate, load_audio
import pretty_midi
from utils import prediction2label
from model import AudioModel
from scipy.signal import resample
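
# Downsample a log-CQT matrix along its time axis (axis 0) from the native CQT
# frame rate (44100 / 160 ≈ 275.6 fps) to target_fs frames per second.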
def downsample_log_cqt(cqt_matrix, target_fs=5):
    original_fs = 44100 / 160
    ratio = original_fs / target_fs
    downsampled = resample(cqt_matrix, int(cqt_matrix.shape[0] / ratio), axis=0)
    return downsampled
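
# Generic variant: resample a (T, F) matrix along the time axis from
# original_fs to target_fs.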
def downsample_matrix(mat, original_fs, target_fs):
    ratio = original_fs / target_fs
    return resample(mat, int(mat.shape[0] / ratio), axis=0)
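
# Compute an 88-bin log-magnitude CQT of the recording, downsample it to ~5 fps,
# and return it as a (1, 1, T, 88) float tensor on CPU.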
def get_cqt_from_mp3(mp3_path):
    sample_rate = 44100
    hop_length = 160
    y, sr = librosa.load(mp3_path, sr=sample_rate, mono=True)
    cqt = librosa.cqt(y, sr=sr, hop_length=hop_length, n_bins=88, bins_per_octave=12)
    log_cqt = librosa.amplitude_to_db(np.abs(cqt))
    log_cqt = log_cqt.T  # shape (T, 88)
    log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
    print(f"cqt shape: {log_cqt.shape}")
    return cqt_tensor
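
# Transcribe the recording to MIDI with the piano_transcription_inference model,
# then build a (1, 2, T, 88) CPU tensor at 5 fps that stacks the velocity-normalized
# piano roll and a binary onset map.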
def get_pianoroll_from_mp3(mp3_path):
    audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(device="cuda" if torch.cuda.is_available() else "cpu")
    midi_path = "temp.mid"
    transcriptor.transcribe(audio, midi_path)
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    fs = 5  # piano-roll frames per second
    piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T  # shape: (T, 88)
    piano_roll = piano_roll / 127
    time_steps = piano_roll.shape[0]
    onsets = np.zeros_like(piano_roll)
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            pitch = note.pitch - 21
            onset_frame = int(note.start * fs)
            if 0 <= pitch < 88 and onset_frame < time_steps:
                onsets[onset_frame, pitch] = 1.0
    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cpu().float()
    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cpu().float()
    out_tensor = torch.cat([pr_tensor, on_tensor], dim=1)
    print(f"piano_roll shape: {out_tensor.shape}")
    return out_tensor.transpose(2, 3)
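
# Build the input for the requested representation ("cqt5", "pianoroll5" or
# "multimodal5", optionally restricted via "only_cqt"/"only_pr" in the rep string),
# run the five saved checkpoints of model_name, and return the mean predicted
# difficulty class over the ensemble.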
def predict_difficulty(mp3_path, model_name, rep):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if "only_cqt" in rep:
        only_cqt, only_pr = True, False
        rep_clean = "multimodal5"
    elif "only_pr" in rep:
        only_cqt, only_pr = False, True
        rep_clean = "multimodal5"
    else:
        only_cqt = only_pr = False
        rep_clean = rep
    model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr).to(device)
    checkpoints = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=device, weights_only=False)
                   for i in range(5)]
    if rep == "cqt5":
        inp_data = get_cqt_from_mp3(mp3_path).to(device)
    elif rep == "pianoroll5":
        inp_data = get_pianoroll_from_mp3(mp3_path).to(device)
    elif rep_clean == "multimodal5":
        x1 = get_pianoroll_from_mp3(mp3_path).to(device)
        x2 = get_cqt_from_mp3(mp3_path).to(device)
        inp_data = [x1, x2]
    else:
        raise ValueError(f"Representation {rep} not supported")
    preds = []
    for checkpoint in checkpoints:
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        with torch.inference_mode():
            logits = model(inp_data, None)
        pred = prediction2label(logits).item()
        preds.append(pred)
    return mean(preds)
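
# Example: predict the difficulty of a local MP3 with the multimodal model.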
if __name__ == "__main__":
    mp3_path = "yt_audio.mp3"
    model_name = "audio_midi_multi_ps_v5"
    pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="multimodal5")
    print(f"Multimodal difficulty prediction: {pred_multi}")