NanoMaestro Full model weights released
- app.py +318 -0
- best_acc_weights.pickle +3 -0
- dataset/__init__.py +0 -0
- dataset/__pycache__/__init__.cpython-312.pyc +0 -0
- dataset/__pycache__/e_piano.cpython-312.pyc +0 -0
- dataset/e_piano.py +169 -0
- generate.py +134 -0
- model/__init__.py +0 -0
- model/__pycache__/__init__.cpython-312.pyc +0 -0
- model/__pycache__/__init__.cpython-313.pyc +0 -0
- model/__pycache__/music_transformer.cpython-312.pyc +0 -0
- model/__pycache__/music_transformer.cpython-313.pyc +0 -0
- model/__pycache__/positional_encoding.cpython-312.pyc +0 -0
- model/__pycache__/positional_encoding.cpython-313.pyc +0 -0
- model/__pycache__/rpr.cpython-312.pyc +0 -0
- model/__pycache__/rpr.cpython-313.pyc +0 -0
- model/music_transformer.py +135 -0
- model/positional_encoding.py +23 -0
- model/rpr.py +171 -0
- processor.py +267 -0
- utilities/__init__.py +0 -0
- utilities/__pycache__/__init__.cpython-312.pyc +0 -0
- utilities/__pycache__/__init__.cpython-313.pyc +0 -0
- utilities/__pycache__/argument_funcs.cpython-312.pyc +0 -0
- utilities/__pycache__/constants.cpython-312.pyc +0 -0
- utilities/__pycache__/constants.cpython-313.pyc +0 -0
- utilities/__pycache__/device.cpython-312.pyc +0 -0
- utilities/__pycache__/device.cpython-313.pyc +0 -0
- utilities/argument_funcs.py +228 -0
- utilities/constants.py +28 -0
- utilities/device.py +67 -0
app.py
ADDED
@@ -0,0 +1,318 @@
import gradio as gr
import torch
import os
import random
import glob
import numpy as np
import pretty_midi
import scipy.io.wavfile

# --- Dependencies from your project ---
# (Make sure these files are in the same directory)
try:
    from model.music_transformer import MusicTransformer
    from processor import encode_midi, decode_midi
    from dataset.e_piano import process_midi
    from utilities.constants import *
    from utilities.device import get_device, use_cuda
except ImportError as e:
    print("Error: Could not import necessary files.")
    print("Make sure app.py is in the same folder as 'model', 'processor.py', etc.")
    print(f"Details: {e}")
    exit()

# --- Your Model's Hyperparameters ---
# (Pulled from your training logs)
MODEL_CONFIG = {
    "n_layers": 6,
    "num_heads": 8,
    "d_model": 512,
    "dim_feedforward": 1024,
    "max_sequence": 2048,
    "rpr": True
}
# ------------------------------------

# Global variable to hold the loaded model
model = None
device = get_device()
print(f"Using device: {device}")


def load_model(model_path):
    """
    Loads the trained MusicTransformer model into memory.
    """
    global model
    if model_path is None or not os.path.exists(model_path):
        return "Error: Model file not found. Please check the path."

    try:
        print("Loading model...")
        model = MusicTransformer(
            n_layers=MODEL_CONFIG["n_layers"],
            num_heads=MODEL_CONFIG["num_heads"],
            d_model=MODEL_CONFIG["d_model"],
            dim_feedforward=MODEL_CONFIG["dim_feedforward"],
            max_sequence=MODEL_CONFIG["max_sequence"],
            rpr=MODEL_CONFIG["rpr"]
        ).to(device)

        # Load the weights, mapping to the correct device
        model.load_state_dict(
            torch.load(model_path, map_location=device, weights_only=True)
        )
        model.eval()
        print("Model loaded successfully.")
        return f"Model '{model_path}' loaded successfully."
    except Exception as e:
        return f"Error loading model: {e}"


# --- NEW FUNCTION ---
def midi_to_wav(midi_file_path, wav_file_path):
    """
    Synthesizes a MIDI file to a WAV file using pretty_midi's
    built-in (simple) sine wave synthesizer.
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_file_path)
        # Synthesize the audio at a 44.1kHz sample rate
        audio_data = pm.synthesize(fs=44100)
        # pretty_midi returns float samples; normalize and scale to 16-bit PCM
        # (a plain astype(np.int16) would truncate everything to near-silence)
        peak = np.max(np.abs(audio_data)) or 1.0
        audio_data = (audio_data / peak * 32767).astype(np.int16)
        # Write as a 16-bit WAV file
        scipy.io.wavfile.write(wav_file_path, 44100, audio_data)
        return wav_file_path
    except Exception as e:
        print(f"Error during MIDI to WAV conversion: {e}")
        return None


# --- END NEW FUNCTION ---

def generate_music(primer_type, uploaded_midi, upload_start_location, maestro_path, maestro_start_location,
                   primer_length, generation_length_new, progress=gr.Progress(track_tqdm=True)):
    """
    The main function called by the Gradio button.
    """
    global model
    if model is None:
        # --- MODIFICATION: Return 3 values on error ---
        yield "Error: Model is not loaded. Please load a model first.", None, None
        return

    try:
        # --- 1. Prepare the Primer ---
        primer = None
        num_primer = 0

        total_target_length = primer_length + generation_length_new
        if total_target_length > MODEL_CONFIG["max_sequence"]:
            total_target_length = MODEL_CONFIG["max_sequence"]
            yield f"Warning: Clamping to {total_target_length} tokens.", None, None

        if primer_type == "Generate from Silence":
            yield "Generating from silence...", None, None
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=device)
            num_primer = 1

        elif primer_type == "Random Maestro MIDI":
            yield "Finding random Maestro file...", None, None
            if maestro_path is None or not os.path.isdir(maestro_path):
                yield f"Error: Maestro path '{maestro_path}' is not valid.", None, None
                return

            midi_files = glob.glob(os.path.join(maestro_path, "**", "*.mid"), recursive=True) + \
                         glob.glob(os.path.join(maestro_path, "**", "*.midi"), recursive=True)

            if not midi_files:
                yield f"Error: No .mid/.midi files found in '{maestro_path}'.", None, None
                return

            random_file = random.choice(midi_files)
            yield f"Tokenizing random file: {os.path.basename(random_file)}...", None, None
            raw_mid = encode_midi(random_file)

            is_random_start = (maestro_start_location == "Random Location")
            primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=is_random_start)

            primer = torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)
            num_primer = primer.shape[0]

        elif primer_type == "Upload My Own MIDI":
            if uploaded_midi is None:
                yield "Error: Please upload a MIDI file.", None, None
                return

            yield f"Tokenizing uploaded MIDI: {os.path.basename(uploaded_midi.name)}...", None, None
            raw_mid = encode_midi(uploaded_midi.name)
            if not raw_mid:
                yield "Error: Could not read MIDI messages.", None, None
                return

            is_random_start = (upload_start_location == "Random Location")
            primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=is_random_start)
            primer = torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)
            num_primer = primer.shape[0]

        if num_primer == 0:
            yield "Error: Primer processing resulted in 0 tokens.", None, None
            return

        # --- 2. Run Generation ---
        yield f"Primed with {num_primer} tokens. Generating {generation_length_new} new tokens...", None, None

        primer_batch = primer.unsqueeze(0)

        model.eval()
        with torch.set_grad_enabled(False):
            rand_seq = model.generate(primer_batch, total_target_length, beam=0)

        # --- 3. Process and Save Output ---
        generated_only_tokens = rand_seq[0][num_primer:]

        if len(generated_only_tokens) == 0:
            yield "Warning: Generation produced 0 new tokens.", None, None
            return

        # --- MODIFICATION: Define output paths ---
        midi_output_filename = "generation_output.mid"
        wav_output_filename = "generation_output.wav"

        # Save the MIDI file
        decode_midi(generated_only_tokens.cpu().numpy(), midi_output_filename)

        # --- MODIFICATION: Synthesize MIDI to WAV ---
        yield "Synthesizing audio...", midi_output_filename, None
        wav_path = midi_to_wav(midi_output_filename, wav_output_filename)

        if wav_path:
            yield "Generation Complete!", midi_output_filename, wav_path
        else:
            yield "Generation complete (WAV synthesis failed).", midi_output_filename, None

    except Exception as e:
        yield f"An error occurred: {e}", None, None


# --- Build the Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🎹 Music Transformer Generation UI")
    gr.Markdown("Load your trained model and generate music from silence, a random seed, or your own MIDI file.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Load Model")
            model_path_input = gr.Textbox(
                label="Path to your .pickle model file",
                value="best_acc_weights.pickle"
            )
            load_button = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Model Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("### 2. Configure Generation")

            primer_type_input = gr.Radio(
                label="Choose Primer Type",
                choices=["Generate from Silence", "Random Maestro MIDI", "Upload My Own MIDI"],
                value="Generate from Silence"
            )

            with gr.Column(visible=False) as maestro_options:
                maestro_path_input = gr.Textbox(
                    label="Path to RAW Maestro MIDI Folder (searches all subfolders)",
                    value="./maestro-v2.0.0"
                )
                maestro_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Random Location",
                    info="Selects a random chunk from the file, giving more variety."
                )

            with gr.Column(visible=False) as upload_options:
                uploaded_midi_input = gr.File(
                    label="Upload Your MIDI Primer",
                    file_types=[".mid", ".midi"]
                )
                upload_start_location_input = gr.Radio(
                    label="Primer Start Location",
                    choices=["Start of File", "Random Location"],
                    value="Start of File"
                )

            primer_length_slider = gr.Slider(
                label="Primer Length (Tokens)",
                minimum=64,
                maximum=2000,
                value=512,
                step=32,
                info="How many tokens to use from the primer file. Ignored for 'Silence'."
            )

            generation_length_slider = gr.Slider(
                label="New Tokens to Generate",
                minimum=128,
                maximum=2048,
                value=1024,
                step=32,
                info="How many new tokens to create after the primer."
            )

            generate_button = gr.Button("Generate Music", variant="primary")

    with gr.Row():
        gr.Markdown("### 3. Get Your Music")
        status_output = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        output_midi_file = gr.File(label="Download Generated MIDI")
        # --- MODIFICATION: Added Audio player ---
        output_wav_file = gr.Audio(label="Listen to Generated WAV", type="filepath")
        # --- END MODIFICATION ---


    # --- UI Event Listeners ---

    def update_ui(primer_type):
        return {
            maestro_options: gr.Column(visible=(primer_type == "Random Maestro MIDI")),
            upload_options: gr.Column(visible=(primer_type == "Upload My Own MIDI")),
            primer_length_slider: gr.Slider(visible=(primer_type != "Generate from Silence"))
        }


    primer_type_input.change(
        fn=update_ui,
        inputs=primer_type_input,
        outputs=[maestro_options, upload_options, primer_length_slider]
    )

    load_button.click(
        fn=load_model,
        inputs=model_path_input,
        outputs=load_status
    )

    # --- MODIFICATION: Updated outputs list ---
    generate_button.click(
        fn=generate_music,
        inputs=[
            primer_type_input,
            uploaded_midi_input,
            upload_start_location_input,
            maestro_path_input,
            maestro_start_location_input,
            primer_length_slider,
            generation_length_slider
        ],
        outputs=[status_output, output_midi_file, output_wav_file]  # <-- Added WAV output
    )
    # --- END MODIFICATION ---

if __name__ == "__main__":
    # Check if CUDA is available and set device
    if (not torch.cuda.is_available()):
        print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
        use_cuda(False)

    print("Launching Gradio UI...")
    app.launch()
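For anyone who wants to use the checkpoint without the Gradio UI, here is a minimal headless sketch of the same load-and-generate path that app.py wires up. The hyperparameters mirror MODEL_CONFIG above; "headless_output.mid" is just an illustrative output path.

import torch
from model.music_transformer import MusicTransformer
from processor import decode_midi
from utilities.constants import TORCH_LABEL_TYPE
from utilities.device import get_device, use_cuda

if not torch.cuda.is_available():
    use_cuda(False)
device = get_device()

# Same hyperparameters as MODEL_CONFIG in app.py
model = MusicTransformer(n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
                         max_sequence=2048, rpr=True).to(device)
model.load_state_dict(torch.load("best_acc_weights.pickle", map_location=device, weights_only=True))
model.eval()

# Token 372 is the "silence" primer app.py uses (velocity start index 356 + 64 // 4)
primer = torch.tensor([[372]], dtype=TORCH_LABEL_TYPE, device=device)
with torch.no_grad():
    seq = model.generate(primer, target_seq_length=512, beam=0)

# Drop the single primer token and write the rest out as MIDI
decode_midi(seq[0][1:].cpu().numpy(), "headless_output.mid")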
best_acc_weights.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bba9e1c91f449753383895c224383dd0ab8402ed003d0d6eb83a4d3f9a3de5df
size 59442741
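Note that these three lines are a Git LFS pointer, not the checkpoint itself: after cloning, the roughly 57 MB of actual weights are fetched with `git lfs pull` (or by any client that resolves LFS pointers).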
dataset/__init__.py
ADDED
File without changes
dataset/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (153 Bytes).
dataset/__pycache__/e_piano.cpython-312.pyc
ADDED
Binary file (5.88 kB).
dataset/e_piano.py
ADDED
@@ -0,0 +1,169 @@
import os
import pickle
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from utilities.constants import *
from utilities.device import cpu_device

SEQUENCE_START = 0

# EPianoDataset
class EPianoDataset(Dataset):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Pytorch Dataset for the Maestro e-piano dataset (https://magenta.tensorflow.org/datasets/maestro).
    Recommended to use with Dataloader (https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

    Uses all files found in the given root directory of pre-processed (preprocess_midi.py)
    Maestro midi files.
    ----------
    """

    def __init__(self, root, max_seq=2048, random_seq=True):
        self.root = root
        self.max_seq = max_seq
        self.random_seq = random_seq

        fs = [os.path.join(root, f) for f in os.listdir(self.root)]
        self.data_files = [f for f in fs if os.path.isfile(f)]

    # __len__
    def __len__(self):
        """
        ----------
        Author: Damon Gwinn
        ----------
        How many data files exist in the given directory
        ----------
        """

        return len(self.data_files)

    # __getitem__
    def __getitem__(self, idx):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Gets the indexed midi batch. Gets random sequence or from start depending on random_seq.

        Returns the input and the target.
        ----------
        """

        # All data on cpu to allow for the Dataloader to multithread
        i_stream = open(self.data_files[idx], "rb")
        # return pickle.load(i_stream), None
        raw_mid = torch.tensor(pickle.load(i_stream), dtype=TORCH_LABEL_TYPE, device=cpu_device())
        i_stream.close()

        x, tgt = process_midi(raw_mid, self.max_seq, self.random_seq)

        return x, tgt

# process_midi
def process_midi(raw_mid, max_seq, random_seq):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Takes in pre-processed raw midi and returns the input and target. Can use a random sequence or
    go from the start based on random_seq.
    ----------
    """

    x   = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device())
    tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device())

    raw_len  = len(raw_mid)
    full_seq = max_seq + 1  # Performing seq2seq

    if(raw_len == 0):
        return x, tgt

    if(raw_len < full_seq):
        x[:raw_len]     = raw_mid
        tgt[:raw_len-1] = raw_mid[1:]
        tgt[raw_len]    = TOKEN_END
    else:
        # Randomly selecting a range
        if(random_seq):
            end_range = raw_len - full_seq
            start = random.randint(SEQUENCE_START, end_range)

        # Always taking from the start to as far as we can
        else:
            start = SEQUENCE_START

        end = start + full_seq

        data = raw_mid[start:end]

        x = data[:max_seq]
        tgt = data[1:full_seq]


    # print("x:",x)
    # print("tgt:",tgt)

    return x, tgt


# create_epiano_datasets
def create_epiano_datasets(dataset_root, max_seq, random_seq=True):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Creates train, evaluation, and test EPianoDataset objects for a pre-processed (preprocess_midi.py)
    root containing train, val, and test folders.
    ----------
    """

    train_root = os.path.join(dataset_root, "train")
    val_root = os.path.join(dataset_root, "val")
    test_root = os.path.join(dataset_root, "test")

    train_dataset = EPianoDataset(train_root, max_seq, random_seq)
    val_dataset = EPianoDataset(val_root, max_seq, random_seq)
    test_dataset = EPianoDataset(test_root, max_seq, random_seq)

    return train_dataset, val_dataset, test_dataset

# compute_epiano_accuracy
def compute_epiano_accuracy(out, tgt):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Computes the average accuracy for the given input and output batches. Accuracy uses softmax
    of the output.
    ----------
    """

    softmax = nn.Softmax(dim=-1)
    out = torch.argmax(softmax(out), dim=-1)

    out = out.flatten()
    tgt = tgt.flatten()

    mask = (tgt != TOKEN_PAD)

    out = out[mask]
    tgt = tgt[mask]

    # Empty
    if(len(tgt) == 0):
        return 1.0

    num_right = (out == tgt)
    num_right = torch.sum(num_right).type(TORCH_FLOAT)

    acc = num_right / len(tgt)

    return acc
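As a quick illustration of what process_midi returns (token values here are arbitrary in-vocabulary integers, max_seq is kept small, and TOKEN_PAD/TOKEN_END come from utilities/constants.py):

import torch
from dataset.e_piano import process_midi

# Shorter than max_seq: the input is right-padded with TOKEN_PAD, and the target
# is the left-shifted input plus an end-of-sequence marker, padded with TOKEN_PAD
short = torch.tensor([372, 60, 188, 260])
x, tgt = process_midi(short, max_seq=8, random_seq=False)

# Longer than max_seq: a max_seq-long window is cropped out, and the target at
# each position is the next input token
long = torch.arange(0, 100)
x, tgt = process_midi(long, max_seq=8, random_seq=False)
print(x.tolist())    # [0, 1, 2, 3, 4, 5, 6, 7]
print(tgt.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8]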
generate.py
ADDED
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
import os
import random
import pretty_midi
import processor

from processor import encode_midi, decode_midi

from utilities.argument_funcs import parse_generate_args, print_generate_args
from model.music_transformer import MusicTransformer
from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, process_midi
from torch.utils.data import DataLoader
from torch.optim import Adam

from utilities.constants import *
from utilities.device import get_device, use_cuda


# main
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Generates music from a model specified by command line arguments
    ----------
    """

    args = parse_generate_args()
    print_generate_args(args)

    if (args.force_cpu):
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    # --- MODIFIED LOGIC ---
    # Can be None, an integer index to dataset, or a file path
    if (args.primer_file is None):
        # --- Load dataset ONLY if no primer file is given ---
        print("No primer file provided, loading dataset to pick a random primer...")
        _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
        # --- --- --- --- --- --- --- --- --- --- --- --- --- ---

        f = str(random.randrange(len(dataset)))
        idx = int(f)
        primer, _ = dataset[idx]
        primer = primer.to(get_device())
        num_primer = primer.shape[0]
        print("Using primer index:", idx, "(", dataset.data_files[idx], ")")

    else:
        # --- Primer file is provided, NO DATASET NEEDED (unless it's an index) ---
        f = args.primer_file

        # --- NEW: Check for "silence" ---
        if (f.lower() == "silence"):
            print("Generating from silence...")
            # Create a primer with one token: a medium velocity (64 // 4 = 16)
            # Velocity START_IDX = 356. Token = 356 + 16 = 372
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]  # This will be 1

        # This part handles if the primer is an integer index (e.g., "3")
        elif (f.isdigit()):
            print("Primer file is an index, loading dataset...")
            # --- Load dataset ONLY if primer is an index ---
            _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
            # --- --- --- --- --- --- --- --- --- --- --- ---
            idx = int(f)
            primer, _ = dataset[idx]
            primer = primer.to(get_device())
            num_primer = primer.shape[0]
            print("Using primer index:", idx, "(", dataset.data_files[idx], ")")

        # This part handles if the primer is a MIDI file path (e.g., "my_primer.mid")
        else:
            print("Primer file is a MIDI path. Loading and tokenizing...")
            raw_mid = encode_midi(f)
            if (len(raw_mid) == 0):
                print("Error: No midi messages in primer file:", f)
                return

            primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
            primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]  # Get the actual primer length
            print("Using primer file:", f)

    # --- END MODIFIED LOGIC ---

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())

    model.load_state_dict(torch.load(args.model_weights, map_location=get_device()))

    # --- MODIFICATION: Don't save a primer if we started from silence ---
    if (args.primer_file is None or args.primer_file.lower() != "silence"):
        f_path = os.path.join(args.output_dir, "primer.mid")
        decode_midi(primer.cpu().numpy(), file_path=f_path)
    # --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

    # GENERATION
    model.eval()
    with torch.set_grad_enabled(False):
        if (args.beam > 0):
            print("BEAM:", args.beam)

            beam_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=args.beam)

            # --- MODIFICATION: Slice the primer ---
            generated_only = beam_seq[0][num_primer:]
            # ------------------------------------

            f_path = os.path.join(args.output_dir, "beam.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)
        else:
            print("RAND DIST")

            rand_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=0)

            # --- MODIFICATION: Slice the primer ---
            generated_only = rand_seq[0][num_primer:]
            # ------------------------------------

            f_path = os.path.join(args.output_dir, "rand.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)


if __name__ == "__main__":
    main()
model/__init__.py
ADDED
File without changes
model/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (151 Bytes).
model/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (155 Bytes).
model/__pycache__/music_transformer.cpython-312.pyc
ADDED
Binary file (8.13 kB).
model/__pycache__/music_transformer.cpython-313.pyc
ADDED
Binary file (6.56 kB).
model/__pycache__/positional_encoding.cpython-312.pyc
ADDED
Binary file (2.05 kB).
model/__pycache__/positional_encoding.cpython-313.pyc
ADDED
Binary file (2.11 kB).
model/__pycache__/rpr.cpython-312.pyc
ADDED
Binary file (19.3 kB).
model/__pycache__/rpr.cpython-313.pyc
ADDED
Binary file (11.3 kB).
model/music_transformer.py
ADDED
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
from torch.nn.modules.normalization import LayerNorm
import random

from utilities.constants import *
from utilities.device import get_device

from .positional_encoding import PositionalEncoding
from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR


# MusicTransformer
class MusicTransformer(nn.Module):
    def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
                 dropout=0.1, max_sequence=2048, rpr=False):
        super(MusicTransformer, self).__init__()

        self.dummy = DummyDecoder()

        self.nlayers = n_layers
        self.nhead = num_heads
        self.d_model = d_model
        self.d_ff = dim_feedforward
        self.dropout = dropout
        self.max_seq = max_sequence
        self.rpr = rpr

        # Input embedding
        self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)

        # Base transformer
        if(not self.rpr):
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy
            )
        else:
            encoder_norm = LayerNorm(self.d_model)
            encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
            encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
            )

        # Final output is a softmaxed linear layer
        self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
        self.softmax = nn.Softmax(dim=-1)

    # forward
    def forward(self, x, mask=True):
        # --- FIX: USE DEVICE OF INPUT TENSOR x ---
        if(mask is True):
            # Generate mask on the same device as input x
            mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(x.device)
        else:
            mask = None
        # -----------------------------------------

        x = self.embedding(x)

        # Input shape is (max_seq, batch_size, d_model)
        x = x.permute(1,0,2)

        x = self.positional_encoding(x)

        # Since there are no true decoder layers, the tgt is unused
        x_out = self.transformer(src=x, tgt=x, src_mask=mask)

        # Back to (batch_size, max_seq, d_model)
        x_out = x_out.permute(1,0,2)

        y = self.Wout(x_out)
        return y

    # generate
    def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
        assert (not self.training), "Cannot generate while in training mode"

        print("Generating sequence of max length:", target_seq_length)

        batch_size = primer.shape[0]
        gen_seq = torch.full((batch_size, target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

        num_primer = primer.shape[1]
        gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())

        cur_i = num_primer
        while(cur_i < target_seq_length):
            y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
            token_probs = y[:, cur_i-1, :]

            if(beam == 0):
                beam_ran = 2.0
            else:
                beam_ran = random.uniform(0,1)

            if(beam_ran <= beam_chance):
                token_probs = token_probs.flatten()
                top_res, top_i = torch.topk(token_probs, beam)

                beam_rows = top_i // VOCAB_SIZE
                beam_cols = top_i % VOCAB_SIZE

                gen_seq = gen_seq[beam_rows, :]
                gen_seq[..., cur_i] = beam_cols

            else:
                distrib = torch.distributions.categorical.Categorical(probs=token_probs)
                next_token = distrib.sample()
                gen_seq[:, cur_i] = next_token

                if(next_token == TOKEN_END):
                    print("Model called end of sequence at:", cur_i, "/", target_seq_length)
                    break

            cur_i += 1
            if(cur_i % 50 == 0):
                print(cur_i, "/", target_seq_length)

        return gen_seq[:, :cur_i]

# Used as a dummy to nn.Transformer
class DummyDecoder(nn.Module):
    def __init__(self):
        super(DummyDecoder, self).__init__()

    def forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, **kwargs):
        return memory
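A small shape sanity check for the encoder-only design above (hyperparameters are arbitrary and much smaller than the released checkpoint; VOCAB_SIZE is defined in utilities/constants.py). Because the decoder stack is replaced by DummyDecoder, the forward pass is just embedding, positional encoding, masked self-attention layers, and the final linear projection:

import torch
from model.music_transformer import MusicTransformer

m = MusicTransformer(n_layers=2, num_heads=4, d_model=64, dim_feedforward=128,
                     max_sequence=64, rpr=True).eval()
x = torch.randint(0, 388, (1, 16))   # (batch, seq) of event-token ids
with torch.no_grad():
    logits = m(x)                    # (batch, seq, VOCAB_SIZE) unnormalized scores
print(logits.shape)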
model/positional_encoding.py
ADDED
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn
import math

# PositionalEncoding
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
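For reference, the buffer built above is the standard sinusoidal encoding from the linked PyTorch tutorial (and "Attention Is All You Need"); the div_term computation implements

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$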
model/rpr.py
ADDED
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn.modules.linear import Linear
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.normalization import LayerNorm
from torch.nn.init import *

class TransformerEncoderRPR(Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoderRPR, self).__init__()
        self.layers = torch.nn.ModuleList([encoder_layer for _ in range(num_layers)])  # Fix for tracing
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):
        output = src
        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        if self.norm:
            output = self.norm(output)
        return output

class TransformerEncoderLayerRPR(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
        super(TransformerEncoderLayerRPR, self).__init__()
        self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len)
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class MultiheadAttentionRPR(Module):
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
        super(MultiheadAttentionRPR, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
        self.add_zero_attn = add_zero_attn

        if er_len is not None:
            self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32))
        else:
            self.Er = None
        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        return multi_head_attention_forward_rpr(
            query, key, value, self.embed_dim, self.num_heads, self.head_dim,
            self.in_proj_weight, self.in_proj_bias,
            None, None, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, use_separate_proj_weight=not self._qkv_same_embed_dim,
            q_proj_weight=getattr(self, 'q_proj_weight', None),
            k_proj_weight=getattr(self, 'k_proj_weight', None),
            v_proj_weight=getattr(self, 'v_proj_weight', None),
            rpr_mat=self.Er)

def multi_head_attention_forward_rpr(query, key, value, embed_dim_to_check, num_heads, head_dim,
                                     in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn,
                                     dropout_p, out_proj_weight, out_proj_bias, training=True,
                                     key_padding_mask=None, need_weights=True, attn_mask=None,
                                     use_separate_proj_weight=False, q_proj_weight=None,
                                     k_proj_weight=None, v_proj_weight=None, static_k=None,
                                     static_v=None, rpr_mat=None):

    tgt_len, bsz, embed_dim = query.size()
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
    else:
        q = F.linear(query, q_proj_weight, in_proj_bias[0:embed_dim])
        k = F.linear(key, k_proj_weight, in_proj_bias[embed_dim:(embed_dim * 2)])
        v = F.linear(value, v_proj_weight, in_proj_bias[(embed_dim * 2):])

    q = q * scaling
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))

    if rpr_mat is not None:
        # Safe Explicit Skew
        len_q = q.shape[1]
        start_idx = rpr_mat.shape[0] - len_q
        rpr_mat_valid = rpr_mat[start_idx:, :]
        qe = torch.einsum("hld,md->hlm", q, rpr_mat_valid)

        # Indices logic (Flatten -> Gather -> Reshape)
        B, L, _ = qe.shape
        # Mask out upper triangle BEFORE skewing
        mask_tri = torch.triu(torch.ones((L, L), device=qe.device, dtype=torch.bool)).flip(0)
        qe = qe.masked_fill(~mask_tri, 0.0)  # Fill with 0 before shift

        zeros = torch.zeros((B, L, 1), device=qe.device, dtype=qe.dtype)
        qe_pad = torch.cat([zeros, qe], dim=2).view(B, -1)

        offsets = torch.arange(L * L, device=qe.device, dtype=torch.int64) + L
        offsets = offsets.unsqueeze(0).expand(B, -1)
        srel = torch.gather(qe_pad, 1, offsets).view(B, L, L)

        attn_output_weights = attn_output_weights + srel

    # --- MASKING FIX (Boolean Masked Fill) ---
    if attn_mask is not None:
        if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0)
        # ONNX prefers masked_fill with boolean mask over adding -inf
        is_causal_mask = (attn_mask == float('-inf')) | (attn_mask < -1e4)
        attn_output_weights = attn_output_weights.masked_fill(is_causal_mask, float('-inf'))

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, tgt_len)

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        return attn_output, attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len).sum(dim=1) / num_heads
    return attn_output, None
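The einsum/gather block in multi_head_attention_forward_rpr is the memory-efficient "skew" trick from Huang et al.'s Music Transformer: rather than materializing a full (L × L × d) tensor of relative embeddings, it computes the product of the queries with the learned relative-position matrix Er and then reindexes it so each query row lines up with its relative distances. Per head, with q already pre-scaled by 1/√d_h as in the code, the attention it implements is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top} + S^{\mathrm{rel}}}{\sqrt{d_h}}\right) V, \qquad S^{\mathrm{rel}} = \mathrm{skew}\!\left(Q E_r^{\top}\right)$$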
processor.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pretty_midi
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
RANGE_NOTE_ON = 128
|
| 5 |
+
RANGE_NOTE_OFF = 128
|
| 6 |
+
RANGE_VEL = 32
|
| 7 |
+
RANGE_TIME_SHIFT = 100
|
| 8 |
+
|
| 9 |
+
START_IDX = {
|
| 10 |
+
'note_on': 0,
|
| 11 |
+
'note_off': RANGE_NOTE_ON,
|
| 12 |
+
'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF,
|
| 13 |
+
'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SustainAdapter:
|
| 18 |
+
def __init__(self, time, type):
|
| 19 |
+
self.start = time
|
| 20 |
+
self.type = type
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SustainDownManager:
|
| 24 |
+
def __init__(self, start, end):
|
| 25 |
+
self.start = start
|
| 26 |
+
self.end = end
|
| 27 |
+
self.managed_notes = []
|
| 28 |
+
self._note_dict = {} # key: pitch, value: note.start
|
| 29 |
+
|
| 30 |
+
def add_managed_note(self, note: pretty_midi.Note):
|
| 31 |
+
self.managed_notes.append(note)
|
| 32 |
+
|
| 33 |
+
def transposition_notes(self):
|
| 34 |
+
for note in reversed(self.managed_notes):
|
| 35 |
+
try:
|
| 36 |
+
note.end = self._note_dict[note.pitch]
|
| 37 |
+
except KeyError:
|
| 38 |
+
note.end = max(self.end, note.end)
|
| 39 |
+
self._note_dict[note.pitch] = note.start
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Divided note by note_on, note_off
|
| 43 |
+
class SplitNote:
|
| 44 |
+
def __init__(self, type, time, value, velocity):
|
| 45 |
+
## type: note_on, note_off
|
| 46 |
+
self.type = type
|
| 47 |
+
self.time = time
|
| 48 |
+
self.velocity = velocity
|
| 49 |
+
self.value = value
|
| 50 |
+
|
| 51 |
+
def __repr__(self):
|
| 52 |
+
return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\
|
| 53 |
+
.format(self.time, self.type, self.value, self.velocity)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Event:
|
| 57 |
+
def __init__(self, event_type, value):
|
| 58 |
+
self.type = event_type
|
| 59 |
+
self.value = value
|
| 60 |
+
|
| 61 |
+
def __repr__(self):
|
| 62 |
+
return '<Event type: {}, value: {}>'.format(self.type, self.value)
|
| 63 |
+
|
| 64 |
+
def to_int(self):
|
| 65 |
+
return START_IDX[self.type] + self.value
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def from_int(int_value):
|
| 69 |
+
info = Event._type_check(int_value)
|
| 70 |
+
return Event(info['type'], info['value'])
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def _type_check(int_value):
|
| 74 |
+
range_note_on = range(0, RANGE_NOTE_ON)
|
| 75 |
+
range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON+RANGE_NOTE_OFF)
|
| 76 |
+
range_time_shift = range(RANGE_NOTE_ON+RANGE_NOTE_OFF,RANGE_NOTE_ON+RANGE_NOTE_OFF+RANGE_TIME_SHIFT)
|
| 77 |
+
|
| 78 |
+
valid_value = int_value
|
| 79 |
+
|
| 80 |
+
if int_value in range_note_on:
|
| 81 |
+
return {'type': 'note_on', 'value': valid_value}
|
| 82 |
+
elif int_value in range_note_off:
|
| 83 |
+
valid_value -= RANGE_NOTE_ON
|
| 84 |
+
return {'type': 'note_off', 'value': valid_value}
|
| 85 |
+
elif int_value in range_time_shift:
|
| 86 |
+
valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
|
| 87 |
+
return {'type': 'time_shift', 'value': valid_value}
|
| 88 |
+
else:
|
| 89 |
+
valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
|
| 90 |
+
return {'type': 'velocity', 'value': valid_value}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _divide_note(notes):
|
| 94 |
+
result_array = []
|
| 95 |
+
notes.sort(key=lambda x: x.start)
|
| 96 |
+
|
| 97 |
+
for note in notes:
|
| 98 |
+
on = SplitNote('note_on', note.start, note.pitch, note.velocity)
|
| 99 |
+
off = SplitNote('note_off', note.end, note.pitch, None)
|
| 100 |
+
result_array += [on, off]
|
| 101 |
+
return result_array
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _merge_note(snote_sequence):
|
| 105 |
+
note_on_dict = {}
|
| 106 |
+
result_array = []
|
| 107 |
+
|
| 108 |
+
for snote in snote_sequence:
|
| 109 |
+
# print(note_on_dict)
|
| 110 |
+
if snote.type == 'note_on':
|
| 111 |
+
note_on_dict[snote.value] = snote
|
| 112 |
+
elif snote.type == 'note_off':
|
| 113 |
+
try:
|
| 114 |
+
on = note_on_dict[snote.value]
|
| 115 |
+
off = snote
|
| 116 |
+
if off.time - on.time == 0:
|
| 117 |
+
continue
|
| 118 |
+
result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time)
|
| 119 |
+
result_array.append(result)
|
| 120 |
+
except:
|
| 121 |
+
print('info removed pitch: {}'.format(snote.value))
|
| 122 |
+
return result_array
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _snote2events(snote: SplitNote, prev_vel: int):
|
| 126 |
+
result = []
|
| 127 |
+
if snote.velocity is not None:
|
| 128 |
+
modified_velocity = snote.velocity // 4
|
| 129 |
+
if prev_vel != modified_velocity:
|
| 130 |
+
result.append(Event(event_type='velocity', value=modified_velocity))
|
| 131 |
+
result.append(Event(event_type=snote.type, value=snote.value))
|
| 132 |
+
return result
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _event_seq2snote_seq(event_sequence):
|
| 136 |
+
timeline = 0
|
| 137 |
+
velocity = 0
|
| 138 |
+
snote_seq = []
|
| 139 |
+
|
| 140 |
+
for event in event_sequence:
|
| 141 |
+
if event.type == 'time_shift':
|
| 142 |
+
timeline += ((event.value+1) / 100)
|
| 143 |
+
if event.type == 'velocity':
|
| 144 |
+
velocity = event.value * 4
|
| 145 |
+
else:
|
| 146 |
+
snote = SplitNote(event.type, timeline, event.value, velocity)
|
| 147 |
+
snote_seq.append(snote)
|
| 148 |
+
return snote_seq
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _make_time_sift_events(prev_time, post_time):
|
| 152 |
+
time_interval = int(round((post_time - prev_time) * 100))
|
| 153 |
+
results = []
|
| 154 |
+
while time_interval >= RANGE_TIME_SHIFT:
|
| 155 |
+
results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT-1))
|
| 156 |
+
time_interval -= RANGE_TIME_SHIFT
|
| 157 |
+
if time_interval == 0:
|
| 158 |
+
return results
|
| 159 |
+
else:
|
| 160 |
+
return results + [Event(event_type='time_shift', value=time_interval-1)]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _control_preprocess(ctrl_changes):
    sustains = []

    manager = None
    for ctrl in ctrl_changes:
        if ctrl.value >= 64 and manager is None:
            # sustain down
            manager = SustainDownManager(start=ctrl.time, end=None)
        elif ctrl.value < 64 and manager is not None:
            # sustain up
            manager.end = ctrl.time
            sustains.append(manager)
            manager = None
        elif ctrl.value < 64 and len(sustains) > 0:
            sustains[-1].end = ctrl.time
    return sustains


def _note_preprocess(susteins, notes):
    note_stream = []

    if susteins:  # if the midi file has sustain controls
        for sustain in susteins:
            for note_idx, note in enumerate(notes):
                if note.start < sustain.start:
                    note_stream.append(note)
                elif note.start > sustain.end:
                    notes = notes[note_idx:]
                    sustain.transposition_notes()
                    break
                else:
                    sustain.add_managed_note(note)

        for sustain in susteins:
            note_stream += sustain.managed_notes

    else:  # otherwise, just push everything into the note stream
        for note_idx, note in enumerate(notes):
            note_stream.append(note)

    note_stream.sort(key=lambda x: x.start)
    return note_stream


def encode_midi(file_path):
    events = []
    notes = []
    mid = pretty_midi.PrettyMIDI(midi_file=file_path)

    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number 64 is the sustain pedal. For the other control change numbers,
        # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)

    dnotes = _divide_note(notes)

    dnotes.sort(key=lambda x: x.time)
    cur_time = 0
    cur_vel = 0
    for snote in dnotes:
        events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
        events += _snote2events(snote=snote, prev_vel=cur_vel)

        cur_time = snote.time
        cur_vel = snote.velocity

    return [e.to_int() for e in events]


def decode_midi(idx_array, file_path=None):
    event_sequence = [Event.from_int(idx) for idx in idx_array]
    snote_seq = _event_seq2snote_seq(event_sequence)
    note_seq = _merge_note(snote_seq)
    note_seq.sort(key=lambda x: x.start)

    mid = pretty_midi.PrettyMIDI()
    # to change the instrument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
    instrument = pretty_midi.Instrument(0, False, "Composed by Super Piano Music Transformer AI")
    instrument.notes = note_seq

    mid.instruments.append(instrument)
    if file_path is not None:
        mid.write(file_path)
    return mid


if __name__ == '__main__':
    encoded = encode_midi('bin/ADIG04.mid')
    print(encoded)
    decoded = decode_midi(encoded, file_path='bin/test.mid')

    ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid')
    print(ins)
    print(ins.instruments[0])
    for i in ins.instruments:
        print(i.control_changes)
        print(i.notes)
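A minimal round-trip sketch of the encoder/decoder above; the file names are placeholders, and pretty_midi's built-in sine synthesis is used only as a quick audible check:

import scipy.io.wavfile

from processor import encode_midi, decode_midi

tokens = encode_midi('primer.mid')                    # placeholder input MIDI
print(len(tokens), 'events, first few:', tokens[:8])

mid = decode_midi(tokens, file_path='roundtrip.mid')  # re-decoded MIDI written to disk
audio = mid.synthesize(fs=22050)                      # rough sine-wave preview
scipy.io.wavfile.write('roundtrip.wav', 22050, audio)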
utilities/__init__.py
ADDED
File without changes
utilities/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (155 Bytes)

utilities/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (159 Bytes)

utilities/__pycache__/argument_funcs.cpython-312.pyc
ADDED
Binary file (13.4 kB)

utilities/__pycache__/constants.cpython-312.pyc
ADDED
Binary file (840 Bytes)

utilities/__pycache__/constants.cpython-313.pyc
ADDED
Binary file (844 Bytes)

utilities/__pycache__/device.cpython-312.pyc
ADDED
Binary file (1.74 kB)

utilities/__pycache__/device.cpython-313.pyc
ADDED
Binary file (1.66 kB)
utilities/argument_funcs.py
ADDED
@@ -0,0 +1,228 @@
import argparse

from .constants import SEPERATOR

# parse_train_args
def parse_train_args():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Argparse arguments for training a model
    ----------
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("-input_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
    parser.add_argument("-output_dir", type=str, default="./saved_models", help="Folder to save model weights. Saves one every epoch")
    parser.add_argument("-weight_modulus", type=int, default=1, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)")
    parser.add_argument("-print_modulus", type=int, default=1, help="How often to print train results for a batch (batch loss, learn rate, etc.)")

    parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
    parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
    parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting")

    parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on")
    parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at")

    parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. Leave as None for a custom scheduler.")
    parser.add_argument("-ce_smoothing", type=float, default=None, help="Smoothing parameter for smoothed cross entropy loss (defaults to no smoothing)")
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
    parser.add_argument("-epochs", type=int, default=100, help="Number of epochs to use")

    parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
    parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
    parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
    parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")

    parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")

    parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate")

    return parser.parse_args()

# print_train_args
def print_train_args(args):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Prints training arguments
    ----------
    """

    print(SEPERATOR)
    print("input_dir:", args.input_dir)
    print("output_dir:", args.output_dir)
    print("weight_modulus:", args.weight_modulus)
    print("print_modulus:", args.print_modulus)
    print("")
    print("n_workers:", args.n_workers)
    print("force_cpu:", args.force_cpu)
    print("tensorboard:", not args.no_tensorboard)
    print("")
    print("continue_weights:", args.continue_weights)
    print("continue_epoch:", args.continue_epoch)
    print("")
    print("lr:", args.lr)
    print("ce_smoothing:", args.ce_smoothing)
    print("batch_size:", args.batch_size)
    print("epochs:", args.epochs)
    print("")
    print("rpr:", args.rpr)
    print("max_sequence:", args.max_sequence)
    print("n_layers:", args.n_layers)
    print("num_heads:", args.num_heads)
    print("d_model:", args.d_model)
    print("")
    print("dim_feedforward:", args.dim_feedforward)
    print("dropout:", args.dropout)
    print(SEPERATOR)
    print("")

# parse_eval_args
def parse_eval_args():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Argparse arguments for evaluating a model
    ----------
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("-dataset_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
    parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
    parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
    parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")

    parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")

    parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
    parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider in the model")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
    parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
    parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")

    parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")

    return parser.parse_args()

# print_eval_args
def print_eval_args(args):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Prints evaluation arguments
    ----------
    """

    print(SEPERATOR)
    print("dataset_dir:", args.dataset_dir)
    print("model_weights:", args.model_weights)
    print("n_workers:", args.n_workers)
    print("force_cpu:", args.force_cpu)
    print("")
    print("batch_size:", args.batch_size)
    print("")
    print("rpr:", args.rpr)
    print("max_sequence:", args.max_sequence)
    print("n_layers:", args.n_layers)
    print("num_heads:", args.num_heads)
    print("d_model:", args.d_model)
    print("")
    print("dim_feedforward:", args.dim_feedforward)
    print(SEPERATOR)
    print("")

# parse_generate_args
def parse_generate_args():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Argparse arguments for generation
    ----------
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("-midi_root", type=str, default="./dataset/e_piano/", help="Midi file to prime the generator with")
    parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to")
    parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. Default is to select a random index.")
    parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")

    parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be")
    parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with")
    parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
    parser.add_argument("-beam", type=int, default=0, help="Beam search k. 0 for random probability sample and 1 for greedy")

    parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
    parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
    parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
    parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")

    parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")

    return parser.parse_args()

# print_generate_args
def print_generate_args(args):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Prints generation arguments
    ----------
    """

    print(SEPERATOR)
    print("midi_root:", args.midi_root)
    print("output_dir:", args.output_dir)
    print("primer_file:", args.primer_file)
    print("force_cpu:", args.force_cpu)
    print("")
    print("target_seq_length:", args.target_seq_length)
    print("num_prime:", args.num_prime)
    print("model_weights:", args.model_weights)
    print("beam:", args.beam)
    print("")
    print("rpr:", args.rpr)
    print("max_sequence:", args.max_sequence)
    print("n_layers:", args.n_layers)
    print("num_heads:", args.num_heads)
    print("d_model:", args.d_model)
    print("")
    print("dim_feedforward:", args.dim_feedforward)
    print(SEPERATOR)
    print("")

# write_model_params
def write_model_params(args, output_file):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Writes given training parameters to text file
    ----------
    """

    o_stream = open(output_file, "w")

    o_stream.write("rpr: " + str(args.rpr) + "\n")
    o_stream.write("lr: " + str(args.lr) + "\n")
    o_stream.write("ce_smoothing: " + str(args.ce_smoothing) + "\n")
    o_stream.write("batch_size: " + str(args.batch_size) + "\n")
    o_stream.write("max_sequence: " + str(args.max_sequence) + "\n")
    o_stream.write("n_layers: " + str(args.n_layers) + "\n")
    o_stream.write("num_heads: " + str(args.num_heads) + "\n")
    o_stream.write("d_model: " + str(args.d_model) + "\n")
    o_stream.write("dim_feedforward: " + str(args.dim_feedforward) + "\n")
    o_stream.write("dropout: " + str(args.dropout) + "\n")

    o_stream.close()
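A small sketch of how these parsers are driven, assuming the repository root is the working directory and that generate.py calls parse_generate_args(); the flag values below are only illustrative:

import sys

from utilities.argument_funcs import parse_generate_args, print_generate_args

# Simulate a command line such as:
#   python generate.py -primer_file primer.mid -target_seq_length 512 --rpr
sys.argv = ['generate.py', '-primer_file', 'primer.mid',
            '-target_seq_length', '512', '--rpr']

args = parse_generate_args()
print_generate_args(args)       # prints the settings between SEPERATOR lines
print(args.target_seq_length)   # 512, parsed as int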
utilities/constants.py
ADDED
@@ -0,0 +1,28 @@
import torch

from processor import RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT

SEPERATOR = "========================="

# Taken from the paper
ADAM_BETA_1 = 0.9
ADAM_BETA_2 = 0.98
ADAM_EPSILON = 10e-9

LR_DEFAULT_START = 1.0
SCHEDULER_WARMUP_STEPS = 4000
# LABEL_SMOOTHING_E = 0.1

# DROPOUT_P = 0.1

TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
TOKEN_PAD = TOKEN_END + 1

VOCAB_SIZE = TOKEN_PAD + 1

TORCH_FLOAT = torch.float32
TORCH_INT = torch.int32

TORCH_LABEL_TYPE = torch.long

PREPEND_ZEROS_WIDTH = 4
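The token ranges imported here are defined in processor.py; assuming the usual Music Transformer event vocabulary (128 note-on pitches, 128 note-off pitches, 32 velocity bins, 100 time-shift steps), the sizes work out as follows:

# Assumed values of the ranges imported from processor.py
RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT = 128, 128, 32, 100

TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
TOKEN_PAD = TOKEN_END + 1
VOCAB_SIZE = TOKEN_PAD + 1
print(TOKEN_END, TOKEN_PAD, VOCAB_SIZE)  # 388 389 390

Event ids then occupy 0-387, with the last two ids typically reserved for end-of-sequence and padding.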
utilities/device.py
ADDED
@@ -0,0 +1,67 @@
# For all things related to devices
#### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS ####

import torch

TORCH_CPU_DEVICE = torch.device("cpu")

if(torch.cuda.device_count() > 0):
    TORCH_CUDA_DEVICE = torch.device("cuda")
else:
    print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
    print("")
    TORCH_CUDA_DEVICE = None

USE_CUDA = True

# use_cuda
def use_cuda(cuda_bool):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Sets whether to use CUDA (if available), or use the CPU (not recommended)
    ----------
    """

    global USE_CUDA
    USE_CUDA = cuda_bool

# get_device
def get_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Grabs the default device. Default device is CUDA if available and use_cuda is not False, CPU otherwise.
    ----------
    """

    if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)):
        return TORCH_CPU_DEVICE
    else:
        return TORCH_CUDA_DEVICE

# cuda_device
def cuda_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Grabs the cuda device (may be None if CUDA is not available)
    ----------
    """

    return TORCH_CUDA_DEVICE

# cpu_device
def cpu_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Grabs the cpu device
    ----------
    """

    return TORCH_CPU_DEVICE
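A short usage sketch for these helpers, assuming the repository root is on the import path:

import torch

from utilities.device import use_cuda, get_device

use_cuda(False)             # e.g. to honor a --force_cpu flag
device = get_device()       # CPU here; CUDA when available and not disabled
x = torch.zeros(2, 3, device=device)
print(device, x.device)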