import glob
import os
import random
import sys

import gradio as gr
import numpy as np
import pretty_midi
import scipy.io.wavfile
import torch

try:
    from model.music_transformer import MusicTransformer
    from processor import encode_midi, decode_midi
    from dataset.e_piano import process_midi
    from utilities.constants import *
    from utilities.device import get_device, use_cuda
except ImportError as e:
    print("Error: Could not import necessary files.")
    print("Make sure app.py is in the same folder as 'model', 'processor.py', etc.")
    print(f"Details: {e}")
    sys.exit(1)

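# Architecture hyperparameters for the checkpoint that will be loaded.
# These must match the values the model was trained with, otherwise
# load_state_dict() below will fail with shape/key mismatches.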
MODEL_CONFIG = {
    "n_layers": 6,
    "num_heads": 8,
    "d_model": 512,
    "dim_feedforward": 1024,
    "max_sequence": 2048,
    "rpr": True
}

model = None
device = get_device()
print(f"Using device: {device}")


def load_model(model_path):
    """
    Loads the trained MusicTransformer model into memory.
    """
    global model
    if model_path is None or not os.path.exists(model_path):
        return "Error: Model file not found. Please check the path."

    try:
        print("Loading model...")
        model = MusicTransformer(
            n_layers=MODEL_CONFIG["n_layers"],
            num_heads=MODEL_CONFIG["num_heads"],
            d_model=MODEL_CONFIG["d_model"],
            dim_feedforward=MODEL_CONFIG["dim_feedforward"],
            max_sequence=MODEL_CONFIG["max_sequence"],
            rpr=MODEL_CONFIG["rpr"]
        ).to(device)

        model.load_state_dict(
            torch.load(model_path, map_location=device, weights_only=True)
        )
        model.eval()
        print("Model loaded successfully.")
        return f"Model '{model_path}' loaded successfully."
    except Exception as e:
        return f"Error loading model: {e}"


def midi_to_wav(midi_file_path, wav_file_path):
    """
    Synthesizes a MIDI file to a WAV file using pretty_midi's
    built-in (simple) sine wave synthesizer.
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_file_path)

        # synthesize() returns float samples roughly in [-1, 1]; peak-normalize and
        # scale to 16-bit PCM, since casting the raw floats to int16 would silence the audio.
        # (pm.fluidsynth() gives more realistic output if pyfluidsynth and a SoundFont are installed.)
        audio_data = pm.synthesize(fs=44100)
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak
        audio_data = (audio_data * 32767).astype(np.int16)

        scipy.io.wavfile.write(wav_file_path, 44100, audio_data)
        return wav_file_path
    except Exception as e:
        print(f"Error during MIDI to WAV conversion: {e}")
        return None


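# generate_music is a generator: each yield streams a (status, midi_file, wav_file)
# tuple into the three output components while generation is in progress.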
def generate_music(primer_type, uploaded_midi, upload_start_location, maestro_path, maestro_start_location,
                   primer_length, generation_length_new, progress=gr.Progress(track_tqdm=True)):
    """
    The main function called by the Gradio button.
    """
    global model
    if model is None:
        yield "Error: Model is not loaded. Please load a model first.", None, None
        return

    try:
        primer = None
        num_primer = 0

        # Slider values can arrive as floats; token counts must be ints.
        primer_length = int(primer_length)
        generation_length_new = int(generation_length_new)

        total_target_length = primer_length + generation_length_new
        if total_target_length > MODEL_CONFIG["max_sequence"]:
            total_target_length = MODEL_CONFIG["max_sequence"]
            yield f"Warning: Clamping to {total_target_length} tokens.", None, None

        if primer_type == "Generate from Silence":
            yield "Generating from silence...", None, None
            # Single-token primer (event id 372 in the model's vocabulary).
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=device)
            num_primer = 1

        elif primer_type == "Random Maestro MIDI":
            yield "Finding random Maestro file...", None, None
            if maestro_path is None or not os.path.isdir(maestro_path):
                yield f"Error: Maestro path '{maestro_path}' is not valid.", None, None
                return

            midi_files = glob.glob(os.path.join(maestro_path, "**", "*.mid"), recursive=True) + \
                         glob.glob(os.path.join(maestro_path, "**", "*.midi"), recursive=True)

            if not midi_files:
                yield f"Error: No .mid/.midi files found in '{maestro_path}'.", None, None
                return

            random_file = random.choice(midi_files)
            yield f"Tokenizing random file: {os.path.basename(random_file)}...", None, None
            raw_mid = encode_midi(random_file)

            is_random_start = (maestro_start_location == "Random Location")
            primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=is_random_start)

            primer = torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)
            num_primer = primer.shape[0]

        elif primer_type == "Upload My Own MIDI":
            if uploaded_midi is None:
                yield "Error: Please upload a MIDI file.", None, None
                return

            yield f"Tokenizing uploaded MIDI: {os.path.basename(uploaded_midi.name)}...", None, None
            raw_mid = encode_midi(uploaded_midi.name)
            if not raw_mid:
                yield "Error: Could not read MIDI messages.", None, None
                return

            is_random_start = (upload_start_location == "Random Location")
            primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=is_random_start)
            primer = torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)
            num_primer = primer.shape[0]

        if num_primer == 0:
            yield "Error: Primer processing resulted in 0 tokens.", None, None
            return

        yield f"Primed with {num_primer} tokens. Generating {generation_length_new} new tokens...", None, None

        primer_batch = primer.unsqueeze(0)

        model.eval()
        with torch.set_grad_enabled(False):
            rand_seq = model.generate(primer_batch, total_target_length, beam=0)

        generated_only_tokens = rand_seq[0][num_primer:]

        if len(generated_only_tokens) == 0:
            yield "Warning: Generation produced 0 new tokens.", None, None
            return

        midi_output_filename = "generation_output.mid"
        wav_output_filename = "generation_output.wav"

        decode_midi(generated_only_tokens.cpu().numpy(), midi_output_filename)

        yield "Synthesizing audio...", midi_output_filename, None
        wav_path = midi_to_wav(midi_output_filename, wav_output_filename)

        if wav_path:
            yield "Generation Complete!", midi_output_filename, wav_path
        else:
            yield "Generation complete (WAV synthesis failed).", midi_output_filename, None

    except Exception as e:
        yield f"An error occurred: {e}", None, None


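# Gradio UI: model loading on the left, generation settings on the right,
# and status / MIDI download / rendered audio outputs at the bottom.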
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🎹 Music Transformer Generation UI")
    gr.Markdown("Load your trained model and generate music from silence, a random seed, or your own MIDI file.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Load Model")
            model_path_input = gr.Textbox(
                label="Path to your .pickle model file",
                value="best_acc_weights.pickle"
            )
            load_button = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Model Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("### 2. Configure Generation")

            primer_type_input = gr.Radio(
                label="Choose Primer Type",
                choices=["Generate from Silence", "Random Maestro MIDI", "Upload My Own MIDI"],
                value="Generate from Silence"
            )

            with gr.Column(visible=False) as maestro_options:
                maestro_path_input = gr.Textbox(
                    label="Path to RAW Maestro MIDI Folder (searches all subfolders)",
                    value="./maestro-v2.0.0"
                )
                maestro_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Random Location",
                    info="Selects a random chunk from the file, giving more variety."
                )

            with gr.Column(visible=False) as upload_options:
                uploaded_midi_input = gr.File(
                    label="Upload Your MIDI Primer",
                    file_types=[".mid", ".midi"]
                )
                upload_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Start of File"
                )

            primer_length_slider = gr.Slider(
                label="Primer Length (Tokens)",
                minimum=64,
                maximum=2000,
                value=512,
                step=32,
                info="How many tokens to use from the primer file. Ignored for 'Silence'."
            )

            generation_length_slider = gr.Slider(
                label="New Tokens to Generate",
                minimum=128,
                maximum=2048,
                value=1024,
                step=32,
                info="How many new tokens to create after the primer."
            )

    generate_button = gr.Button("Generate Music", variant="primary")

    with gr.Row():
        gr.Markdown("### 3. Get Your Music")

    status_output = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        output_midi_file = gr.File(label="Download Generated MIDI")
        output_wav_file = gr.Audio(label="Listen to Generated WAV", type="filepath")

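    # Show or hide the primer-specific input panels when the primer type changes.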
    def update_ui(primer_type):
        return {
            maestro_options: gr.Column(visible=(primer_type == "Random Maestro MIDI")),
            upload_options: gr.Column(visible=(primer_type == "Upload My Own MIDI")),
            primer_length_slider: gr.Slider(visible=(primer_type != "Generate from Silence"))
        }

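    # Event wiring: visibility toggle, model loading, and generation.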
    primer_type_input.change(
        fn=update_ui,
        inputs=primer_type_input,
        outputs=[maestro_options, upload_options, primer_length_slider]
    )

    load_button.click(
        fn=load_model,
        inputs=model_path_input,
        outputs=load_status
    )

    generate_button.click(
        fn=generate_music,
        inputs=[
            primer_type_input,
            uploaded_midi_input,
            upload_start_location_input,
            maestro_path_input,
            maestro_start_location_input,
            primer_length_slider,
            generation_length_slider
        ],
        outputs=[status_output, output_midi_file, output_wav_file]
    )


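# Entry point: warn if CUDA is unavailable (CPU generation is slow), then launch the UI.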
if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("----- WARNING: CUDA devices not detected. This will cause the model to run very slowly! -----")
        use_cuda(False)

    print("Launching Gradio UI...")
    app.launch()