from statistics import mean

import librosa
import numpy as np
import pretty_midi
import torch
from piano_transcription_inference import PianoTranscription, sample_rate, load_audio
from scipy.signal import resample

from model import AudioModel
from utils import prediction2label

def downsample_log_cqt(cqt_matrix, target_fs=5):
    # The CQT below is computed from a 44100 Hz signal with a hop of 160 samples,
    # i.e. a frame rate of 44100 / 160 ≈ 275.6 Hz.
    original_fs = 44100 / 160
    ratio = original_fs / target_fs
    downsampled = resample(cqt_matrix, int(cqt_matrix.shape[0] / ratio), axis=0)
    return downsampled

def downsample_matrix(mat, original_fs, target_fs):
    # General form of the helper above: resample along the time axis (axis 0).
    ratio = original_fs / target_fs
    return resample(mat, int(mat.shape[0] / ratio), axis=0)

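# Illustrative arithmetic (not from the original script): at the 275.6 Hz CQT frame
# rate, a 10-second clip yields about 2756 frames; downsampling with target_fs=5
# keeps int(2756 / 55.125), roughly 50 frames, i.e. 5 frames per second.
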
def get_cqt_from_mp3(mp3_path):
    # This local sample_rate (44100 Hz) deliberately shadows the sample_rate
    # imported from piano_transcription_inference, which is only needed for the
    # transcription step in get_pianoroll_from_mp3.
    sample_rate = 44100
    hop_length = 160
    y, sr = librosa.load(mp3_path, sr=sample_rate, mono=True)
    # 88 semitone bins, matching the number of piano keys.
    cqt = librosa.cqt(y, sr=sr, hop_length=hop_length, n_bins=88, bins_per_octave=12)
    log_cqt = librosa.amplitude_to_db(np.abs(cqt))
    log_cqt = log_cqt.T  # shape (T, 88)
    log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
    print(f"cqt shape: {log_cqt.shape}")
    return cqt_tensor

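# Example (hypothetical file name): get_cqt_from_mp3("clip.mp3") returns a tensor
# of shape (1, 1, T, 88): batch, channel, downsampled time frames, CQT bins.
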
def get_pianoroll_from_mp3(mp3_path):
    # Transcribe the audio to MIDI with the piano transcription model.
    audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(device="cuda" if torch.cuda.is_available() else "cpu")
    midi_path = "temp.mid"
    transcriptor.transcribe(audio, midi_path)
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    fs = 5  # target frame rate in frames per second (matches the downsampled CQT)
    # Keep MIDI pitches 21-108 (the 88 piano keys) and normalize velocities to [0, 1].
    piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T  # shape: (T, 88)
    piano_roll = piano_roll / 127
    time_steps = piano_roll.shape[0]
    # Build a binary onset matrix marking the frame in which each note starts.
    onsets = np.zeros_like(piano_roll)
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            pitch = note.pitch - 21
            onset_frame = int(note.start * fs)
            if 0 <= pitch < 88 and onset_frame < time_steps:
                onsets[onset_frame, pitch] = 1.0
    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cpu().float()
    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cpu().float()
    # Stack the sustained roll and the onsets as two channels: (1, 2, 88, T).
    out_tensor = torch.cat([pr_tensor, on_tensor], dim=1)
    print(f"piano_roll shape: {out_tensor.shape}")
    return out_tensor.transpose(2, 3)  # -> (1, 2, T, 88)

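# Example (hypothetical file name): get_pianoroll_from_mp3("clip.mp3") returns a
# tensor of shape (1, 2, T, 88), where channel 0 is the sustained piano roll and
# channel 1 the binary note onsets, both at 5 frames per second.
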
def predict_difficulty(mp3_path, model_name, rep):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # "only_cqt"/"only_pr" select the multimodal architecture but mask one branch.
    if "only_cqt" in rep:
        only_cqt, only_pr = True, False
        rep_clean = "multimodal5"
    elif "only_pr" in rep:
        only_cqt, only_pr = False, True
        rep_clean = "multimodal5"
    else:
        only_cqt = only_pr = False
        rep_clean = rep
    model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False,
                       only_cqt=only_cqt, only_pr=only_pr).to(device)
    # Load the five checkpoints used for the prediction ensemble.
    checkpoints = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=device, weights_only=False)
                   for i in range(5)]
    if rep == "cqt5":
        inp_data = get_cqt_from_mp3(mp3_path).to(device)
    elif rep == "pianoroll5":
        inp_data = get_pianoroll_from_mp3(mp3_path).to(device)
    elif rep_clean == "multimodal5":
        x1 = get_pianoroll_from_mp3(mp3_path).to(device)
        x2 = get_cqt_from_mp3(mp3_path).to(device)
        inp_data = [x1, x2]
    else:
        raise ValueError(f"Representation {rep} not supported")
    preds = []
    for ckpt in checkpoints:
        model.load_state_dict(ckpt["model_state_dict"])
        model.eval()
        with torch.inference_mode():
            logits = model(inp_data, None)
            pred = prediction2label(logits).item()
            preds.append(pred)
    # Average the per-checkpoint ordinal predictions.
    return mean(preds)

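# A minimal sketch (assumption: the model's 11 classes map to integer difficulty
# levels 0-10) for turning the ensemble mean into a discrete grade:
def mean_to_level(mean_pred: float) -> int:
    """Clamp to the valid range and round to the nearest difficulty level."""
    return int(round(min(max(mean_pred, 0.0), 10.0)))
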
if __name__ == "__main__":
    mp3_path = "yt_audio.mp3"

    model_name = "audio_midi_multi_ps_v5"
    pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="multimodal5")
    print(f"Multimodal: {pred_multi}")

    model_name = "audio_midi_pianoroll_ps_5_v4"
    pred_pr = predict_difficulty(mp3_path, model_name=model_name, rep="pianoroll5")
    print(f"Pianoroll: {pred_pr}")

    # CQT-only prediction: reuse the multimodal checkpoints with the piano-roll
    # branch masked (the "only_cqt" path in predict_difficulty).
    model_name = "audio_midi_multi_ps_v5"
    pred_cqt = predict_difficulty(mp3_path, model_name=model_name, rep="only_cqt")
    print(f"CQT: {pred_cqt}")