"""Gradio front end for generating music with a trained MusicTransformer.

The UI lets the user load a weights file, pick a primer (silence, a random
file from a MIDI folder, or an uploaded MIDI), generate a continuation, and
download the result as MIDI plus a synthesized WAV preview.
"""
import glob
import os
import random

import gradio as gr
import numpy as np
import pretty_midi
import scipy.io.wavfile
import torch

# --- Dependencies from your project ---
# (Make sure these files are in the same directory)
try:
    from model.music_transformer import MusicTransformer
    from processor import encode_midi, decode_midi
    from dataset.e_piano import process_midi
    from utilities.constants import *  # provides TORCH_LABEL_TYPE used below
    from utilities.device import get_device, use_cuda
except ImportError as e:
    print("Error: Could not import necessary files.")
    print("Make sure app.py is in the same folder as 'model', 'processor.py', etc.")
    print(f"Details: {e}")
    # `exit()` is the interactive `site` helper and reports success (status 0);
    # raise SystemExit with a non-zero status so the failure is visible.
    raise SystemExit(1)

# --- Your Model's Hyperparameters ---
# (Pulled from your training logs)
# The keys match the MusicTransformer constructor's keyword arguments.
MODEL_CONFIG = {
    "n_layers": 6,
    "num_heads": 8,
    "d_model": 512,
    "dim_feedforward": 1024,
    "max_sequence": 2048,
    "rpr": True,
}
# ------------------------------------

# Sample rate used both for synthesis and for the WAV header.
SAMPLE_RATE = 44100
# Largest magnitude representable in signed 16-bit PCM.
INT16_MAX = 32767
# Single token used to prime "Generate from Silence".
# NOTE(review): the meaning of 372 depends on the project's token vocabulary,
# which is not visible from this file -- confirm against `processor.py`.
SILENCE_PRIMER_TOKEN = 372
# Fixed output paths (relative to the working directory); every generation
# overwrites the previous files.
MIDI_OUTPUT_FILENAME = "generation_output.mid"
WAV_OUTPUT_FILENAME = "generation_output.wav"

# Global variable to hold the loaded model
model = None
device = get_device()
print(f"Using device: {device}")


def load_model(model_path):
    """Load the trained MusicTransformer weights and publish them globally.

    Args:
        model_path: Filesystem path of the saved state dict.

    Returns:
        A human-readable status string (success or error) for the UI.

    The global ``model`` is only replaced after the weights have loaded
    successfully, so a failed load never leaves a randomly initialised
    network behind for ``generate_music`` to use.
    """
    global model

    if model_path is None or not os.path.exists(model_path):
        return "Error: Model file not found. Please check the path."

    try:
        print("Loading model...")
        candidate = MusicTransformer(**MODEL_CONFIG).to(device)
        # Load the weights, mapping to the correct device
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        candidate.load_state_dict(state_dict)
        candidate.eval()
    except Exception as e:
        return f"Error loading model: {e}"

    model = candidate
    print("Model loaded successfully.")
    return f"Model '{model_path}' loaded successfully."


def midi_to_wav(midi_file_path, wav_file_path):
    """Synthesize a MIDI file to a 16-bit WAV using pretty_midi's synthesizer.

    Args:
        midi_file_path: Path of the MIDI file to render.
        wav_file_path: Path the WAV file is written to.

    Returns:
        ``wav_file_path`` on success, or ``None`` if synthesis failed or the
        MIDI produced no audio.
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_file_path)
        # Synthesize the audio at a 44.1kHz sample rate
        audio = np.asarray(pm.synthesize(fs=SAMPLE_RATE), dtype=np.float64)

        if audio.size == 0:
            print("Error during MIDI to WAV conversion: no audio was synthesized.")
            return None

        # The synthesizer returns a floating-point waveform. Casting floats in
        # [-1, 1] straight to int16 truncates nearly every sample to 0, so
        # normalise to the peak and scale to the int16 range first.
        audio = np.nan_to_num(audio)
        peak = np.abs(audio).max()
        if peak > 0:
            audio = audio / peak
        pcm = (audio * INT16_MAX).astype(np.int16)

        # Write as a 16-bit WAV file
        scipy.io.wavfile.write(wav_file_path, SAMPLE_RATE, pcm)
        return wav_file_path
    except Exception as e:
        print(f"Error during MIDI to WAV conversion: {e}")
        return None


def _uploaded_file_path(uploaded):
    """Return the path of a Gradio upload given as a path or a file wrapper.

    NOTE(review): depending on the Gradio version the File component passes
    either a path string or an object exposing ``.name`` -- both are accepted.
    """
    return getattr(uploaded, "name", uploaded)


def _tokenize_primer(midi_path, primer_length, random_start):
    """Encode a MIDI file and cut a primer of ``primer_length`` tokens.

    Returns the primer as a tensor on ``device``, or ``None`` when the file
    yields no tokens.
    """
    raw_mid = encode_midi(midi_path)
    if not raw_mid:
        return None
    # process_midi returns a pair; only the first element is used as primer.
    primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=random_start)
    return torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)


def generate_music(primer_type, uploaded_midi, upload_start_location,
                   maestro_path, maestro_start_location,
                   primer_length, generation_length_new,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate music; the generator callback behind the "Generate" button.

    Args:
        primer_type: One of the primer radio choices.
        uploaded_midi: The uploaded primer file (used for "Upload My Own MIDI").
        upload_start_location: "Start of File" or "Random Location" (upload).
        maestro_path: Folder searched recursively for .mid/.midi files.
        maestro_start_location: "Start of File" or "Random Location" (Maestro).
        primer_length: Number of primer tokens to take from a MIDI file.
            Not used for "Generate from Silence".
        generation_length_new: Number of new tokens to generate.
        progress: Gradio progress tracker; unused directly, declared so Gradio
            can mirror tqdm progress.

    Yields:
        ``(status_message, midi_path_or_None, wav_path_or_None)`` tuples, one
        per status update.
    """
    if model is None:
        yield "Error: Model is not loaded. Please load a model first.", None, None
        # Without this return the generator would carry on and overwrite the
        # message above with a confusing secondary error.
        return

    try:
        # Sliders may deliver floats; downstream code needs integer lengths.
        primer_length = int(primer_length)
        generation_length_new = int(generation_length_new)

        # --- 1. Prepare the Primer ---
        primer = None

        if primer_type == "Generate from Silence":
            yield "Generating from silence...", None, None
            primer = torch.tensor([SILENCE_PRIMER_TOKEN], dtype=TORCH_LABEL_TYPE, device=device)

        elif primer_type == "Random Maestro MIDI":
            yield "Finding random Maestro file...", None, None
            if maestro_path is None or not os.path.isdir(maestro_path):
                yield f"Error: Maestro path '{maestro_path}' is not valid.", None, None
                return

            midi_files = [
                path
                for pattern in ("*.mid", "*.midi")
                for path in glob.glob(os.path.join(maestro_path, "**", pattern), recursive=True)
            ]
            if not midi_files:
                yield f"Error: No .mid/.midi files found in '{maestro_path}'.", None, None
                return

            random_file = random.choice(midi_files)
            yield f"Tokenizing random file: {os.path.basename(random_file)}...", None, None

            primer = _tokenize_primer(
                random_file, primer_length,
                random_start=(maestro_start_location == "Random Location"),
            )
            if primer is None:
                yield "Error: Could not read MIDI messages.", None, None
                return

        elif primer_type == "Upload My Own MIDI":
            if uploaded_midi is None:
                yield "Error: Please upload a MIDI file.", None, None
                return

            upload_path = _uploaded_file_path(uploaded_midi)
            yield f"Tokenizing uploaded MIDI: {os.path.basename(upload_path)}...", None, None

            primer = _tokenize_primer(
                upload_path, primer_length,
                random_start=(upload_start_location == "Random Location"),
            )
            if primer is None:
                yield "Error: Could not read MIDI messages.", None, None
                return

        num_primer = 0 if primer is None else primer.shape[0]
        if num_primer == 0:
            yield "Error: Primer processing resulted in 0 tokens.", None, None
            return

        # Size the target from the primer actually in use, so "Generate from
        # Silence" (a 1-token primer) is not inflated by the hidden
        # primer-length slider.
        max_sequence = MODEL_CONFIG["max_sequence"]
        total_target_length = num_primer + generation_length_new
        if total_target_length > max_sequence:
            total_target_length = max_sequence
            yield f"Warning: Clamping to {total_target_length} tokens.", None, None

        new_token_count = total_target_length - num_primer
        if new_token_count <= 0:
            yield "Warning: Generation produced 0 new tokens.", None, None
            return

        # --- 2. Run Generation ---
        yield f"Primed with {num_primer} tokens. Generating {new_token_count} new tokens...", None, None

        primer_batch = primer.unsqueeze(0)

        model.eval()
        with torch.no_grad():
            rand_seq = model.generate(primer_batch, total_target_length, beam=0)

        # --- 3. Process and Save Output ---
        # Drop the primer so only newly generated tokens are saved.
        generated_only_tokens = rand_seq[0][num_primer:]

        if len(generated_only_tokens) == 0:
            yield "Warning: Generation produced 0 new tokens.", None, None
            return

        # Save the MIDI file
        decode_midi(generated_only_tokens.cpu().numpy(), MIDI_OUTPUT_FILENAME)

        # Publish the MIDI immediately; the audio preview follows.
        yield "Synthesizing audio...", MIDI_OUTPUT_FILENAME, None
        wav_path = midi_to_wav(MIDI_OUTPUT_FILENAME, WAV_OUTPUT_FILENAME)

        if wav_path:
            yield "Generation Complete!", MIDI_OUTPUT_FILENAME, wav_path
        else:
            yield "Generation complete (WAV synthesis failed).", MIDI_OUTPUT_FILENAME, None

    except Exception as e:
        # Top-level UI boundary: surface any failure as a status message.
        yield f"An error occurred: {e}", None, None


# --- Build the Gradio UI ---
# NOTE(review): the source had lost its indentation; the container nesting
# below is reconstructed from the order of the calls -- confirm the layout.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🎹 Music Transformer Generation UI")
    gr.Markdown("Load your trained model and generate music from silence, a random seed, or your own MIDI file.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Load Model")
            model_path_input = gr.Textbox(
                label="Path to your .pickle model file",
                value="best_acc_weights.pickle"
            )
            load_button = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Model Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("### 2. Configure Generation")
            primer_type_input = gr.Radio(
                label="Choose Primer Type",
                choices=["Generate from Silence", "Random Maestro MIDI", "Upload My Own MIDI"],
                value="Generate from Silence"
            )

            with gr.Column(visible=False) as maestro_options:
                maestro_path_input = gr.Textbox(
                    label="Path to RAW Maestro MIDI Folder (searches all subfolders)",
                    value="./maestro-v2.0.0"
                )
                maestro_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Random Location",
                    info="Selects a random chunk from the file, giving more variety."
                )

            with gr.Column(visible=False) as upload_options:
                uploaded_midi_input = gr.File(
                    label="Upload Your MIDI Primer",
                    file_types=[".mid", ".midi"]
                )
                upload_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Start of File"
                )

            # Hidden initially to agree with update_ui(): the default primer
            # type is "Generate from Silence", for which this slider is unused.
            primer_length_slider = gr.Slider(
                label="Primer Length (Tokens)",
                minimum=64,
                maximum=2000,
                value=512,
                step=32,
                visible=False,
                info="How many tokens to use from the primer file. Ignored for 'Silence'."
            )

            generation_length_slider = gr.Slider(
                label="New Tokens to Generate",
                minimum=128,
                maximum=2048,
                value=1024,
                step=32,
                info="How many new tokens to create after the primer."
            )

            generate_button = gr.Button("Generate Music", variant="primary")

    with gr.Row():
        gr.Markdown("### 3. Get Your Music")

    status_output = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        output_midi_file = gr.File(label="Download Generated MIDI")
        output_wav_file = gr.Audio(label="Listen to Generated WAV", type="filepath")

    # --- UI Event Listeners ---
    def update_ui(primer_type):
        """Show only the option groups relevant to the chosen primer type."""
        return {
            maestro_options: gr.Column(visible=(primer_type == "Random Maestro MIDI")),
            upload_options: gr.Column(visible=(primer_type == "Upload My Own MIDI")),
            primer_length_slider: gr.Slider(visible=(primer_type != "Generate from Silence"))
        }

    primer_type_input.change(
        fn=update_ui,
        inputs=primer_type_input,
        outputs=[maestro_options, upload_options, primer_length_slider]
    )

    load_button.click(
        fn=load_model,
        inputs=model_path_input,
        outputs=load_status
    )

    # The outputs map one-to-one onto the 3-tuples yielded by generate_music.
    generate_button.click(
        fn=generate_music,
        inputs=[
            primer_type_input,
            uploaded_midi_input,
            upload_start_location_input,
            maestro_path_input,
            maestro_start_location_input,
            primer_length_slider,
            generation_length_slider
        ],
        outputs=[status_output, output_midi_file, output_wav_file]
    )


if __name__ == "__main__":
    # Check if CUDA is available and set device
    if not torch.cuda.is_available():
        print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
        use_cuda(False)

    print("Launching Gradio UI...")
    app.launch()