utkucoban committed (verified)
Commit 47dfee0 · Parent(s): a069e25

NanoMaestro Full model weights released

app.py ADDED
@@ -0,0 +1,318 @@
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import random
5
+ import glob
6
+ import numpy as np
7
+ import pretty_midi
8
+ import scipy.io.wavfile
9
+
10
+ # --- Dependencies from your project ---
11
+ # (Make sure these files are in the same directory)
12
+ try:
13
+ from model.music_transformer import MusicTransformer
14
+ from processor import encode_midi, decode_midi
15
+ from dataset.e_piano import process_midi
16
+ from utilities.constants import *
17
+ from utilities.device import get_device, use_cuda
18
+ except ImportError as e:
19
+ print("Error: Could not import necessary files.")
+ print("Make sure app.py is in the same folder as 'model', 'processor.py', etc.")
21
+ print(f"Details: {e}")
22
+ exit()
23
+
24
+ # --- Your Model's Hyperparameters ---
25
+ # (Pulled from your training logs)
26
+ MODEL_CONFIG = {
27
+ "n_layers": 6,
28
+ "num_heads": 8,
29
+ "d_model": 512,
30
+ "dim_feedforward": 1024,
31
+ "max_sequence": 2048,
32
+ "rpr": True
33
+ }
34
+ # ------------------------------------
35
+
36
+ # Global variable to hold the loaded model
37
+ model = None
38
+ device = get_device()
39
+ print(f"Using device: {device}")
40
+
41
+
42
+ def load_model(model_path):
43
+ """
44
+ Loads the trained MusicTransformer model into memory.
45
+ """
46
+ global model
47
+ if model_path is None or not os.path.exists(model_path):
48
+ return "Error: Model file not found. Please check the path."
49
+
50
+ try:
51
+ print("Loading model...")
52
+ model = MusicTransformer(
53
+ n_layers=MODEL_CONFIG["n_layers"],
54
+ num_heads=MODEL_CONFIG["num_heads"],
55
+ d_model=MODEL_CONFIG["d_model"],
56
+ dim_feedforward=MODEL_CONFIG["dim_feedforward"],
57
+ max_sequence=MODEL_CONFIG["max_sequence"],
58
+ rpr=MODEL_CONFIG["rpr"]
59
+ ).to(device)
60
+
61
+ # Load the weights, mapping to the correct device
62
+ model.load_state_dict(
63
+ torch.load(model_path, map_location=device, weights_only=True)
64
+ )
65
+ model.eval()
66
+ print("Model loaded successfully.")
67
+ return f"Model '{model_path}' loaded successfully."
68
+ except Exception as e:
69
+ return f"Error loading model: {e}"
70
+
71
+
72
+ # --- NEW FUNCTION ---
73
+ def midi_to_wav(midi_file_path, wav_file_path):
74
+ """
75
+ Synthesizes a MIDI file to a WAV file using pretty_midi's
76
+ built-in (simple) sine wave synthesizer.
77
+ """
78
+ try:
79
+ pm = pretty_midi.PrettyMIDI(midi_file_path)
80
+ # Synthesize the audio at a 44.1kHz sample rate
81
+ audio_data = pm.synthesize(fs=44100)
82
+ # pretty_midi's synthesize() returns float samples in roughly [-1, 1];
+ # scale to the 16-bit integer range before writing, otherwise the WAV is near-silent
+ peak = np.max(np.abs(audio_data)) if audio_data.size > 0 else 1.0
+ audio_data = np.int16(audio_data / max(peak, 1e-9) * 32767)
+ scipy.io.wavfile.write(wav_file_path, 44100, audio_data)
84
+ return wav_file_path
85
+ except Exception as e:
86
+ print(f"Error during MIDI to WAV conversion: {e}")
87
+ return None
88
+
89
+
90
+ # --- END NEW FUNCTION ---
91
+
92
+ def generate_music(primer_type, uploaded_midi, upload_start_location, maestro_path, maestro_start_location,
93
+ primer_length, generation_length_new, progress=gr.Progress(track_tqdm=True)):
94
+ """
95
+ The main function called by the Gradio button.
96
+ """
97
+ global model
98
+ if model is None:
99
+ # --- MODIFICATION: Return 3 values on error, then stop ---
+ yield "Error: Model is not loaded. Please load a model first.", None, None
+ return
101
+
102
+ try:
103
+ # --- 1. Prepare the Primer ---
104
+ primer = None
105
+ num_primer = 0
106
+
107
+ total_target_length = primer_length + generation_length_new
108
+ if total_target_length > MODEL_CONFIG["max_sequence"]:
109
+ total_target_length = MODEL_CONFIG["max_sequence"]
110
+ yield f"Warning: Clamping to {total_target_length} tokens.", None, None
111
+
112
+ if primer_type == "Generate from Silence":
113
+ yield "Generating from silence...", None, None
114
+ primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=device)
115
+ num_primer = 1
116
+
117
+ elif primer_type == "Random Maestro MIDI":
118
+ yield "Finding random Maestro file...", None, None
119
+ if maestro_path is None or not os.path.isdir(maestro_path):
120
+ yield f"Error: Maestro path '{maestro_path}' is not valid.", None, None
121
+ return
122
+
123
+ midi_files = glob.glob(os.path.join(maestro_path, "**", "*.mid"), recursive=True) + \
124
+ glob.glob(os.path.join(maestro_path, "**", "*.midi"), recursive=True)
125
+
126
+ if not midi_files:
127
+ yield f"Error: No .mid/.midi files found in '{maestro_path}'.", None, None
128
+ return
129
+
130
+ random_file = random.choice(midi_files)
131
+ yield f"Tokenizing random file: {os.path.basename(random_file)}...", None, None
132
+ raw_mid = encode_midi(random_file)
133
+
134
+ is_random_start = (maestro_start_location == "Random Location")
135
+ primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=is_random_start)
136
+
137
+ primer = torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)
138
+ num_primer = primer.shape[0]
139
+
140
+ elif primer_type == "Upload My Own MIDI":
141
+ if uploaded_midi is None:
142
+ yield "Error: Please upload a MIDI file.", None, None
143
+ return
144
+
145
+ yield f"Tokenizing uploaded MIDI: {os.path.basename(uploaded_midi.name)}...", None, None
146
+ raw_mid = encode_midi(uploaded_midi.name)
147
+ if not raw_mid:
148
+ yield "Error: Could not read MIDI messages.", None, None
149
+ return
150
+
151
+ is_random_start = (upload_start_location == "Random Location")
152
+ primer_tokens, _ = process_midi(raw_mid, primer_length, random_seq=is_random_start)
153
+ primer = torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=device)
154
+ num_primer = primer.shape[0]
155
+
156
+ if num_primer == 0:
157
+ yield "Error: Primer processing resulted in 0 tokens.", None, None
158
+ return
159
+
160
+ # --- 2. Run Generation ---
161
+ yield f"Primed with {num_primer} tokens. Generating {generation_length_new} new tokens...", None, None
162
+
163
+ primer_batch = primer.unsqueeze(0)
164
+
165
+ model.eval()
166
+ with torch.set_grad_enabled(False):
167
+ rand_seq = model.generate(primer_batch, total_target_length, beam=0)
168
+
169
+ # --- 3. Process and Save Output ---
170
+ generated_only_tokens = rand_seq[0][num_primer:]
171
+
172
+ if len(generated_only_tokens) == 0:
173
+ yield "Warning: Generation produced 0 new tokens.", None, None
174
+ return
175
+
176
+ # --- MODIFICATION: Define output paths ---
177
+ midi_output_filename = "generation_output.mid"
178
+ wav_output_filename = "generation_output.wav"
179
+
180
+ # Save the MIDI file
181
+ decode_midi(generated_only_tokens.cpu().numpy(), midi_output_filename)
182
+
183
+ # --- MODIFICATION: Synthesize MIDI to WAV ---
184
+ yield "Synthesizing audio...", midi_output_filename, None
185
+ wav_path = midi_to_wav(midi_output_filename, wav_output_filename)
186
+
187
+ if wav_path:
188
+ yield "Generation Complete!", midi_output_filename, wav_path
189
+ else:
190
+ yield "Generation complete (WAV synthesis failed).", midi_output_filename, None
191
+
192
+ except Exception as e:
193
+ yield f"An error occurred: {e}", None, None
194
+
195
+
196
+ # --- Build the Gradio UI ---
197
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
198
+ gr.Markdown("# 🎹 Music Transformer Generation UI")
199
+ gr.Markdown("Load your trained model and generate music from silence, a random seed, or your own MIDI file.")
200
+
201
+ with gr.Row():
202
+ with gr.Column(scale=1):
203
+ gr.Markdown("### 1. Load Model")
204
+ model_path_input = gr.Textbox(
205
+ label="Path to your .pickle model file",
206
+ value="best_acc_weights.pickle"
207
+ )
208
+ load_button = gr.Button("Load Model", variant="primary")
209
+ load_status = gr.Textbox(label="Model Status", interactive=False)
210
+
211
+ with gr.Column(scale=2):
212
+ gr.Markdown("### 2. Configure Generation")
213
+
214
+ primer_type_input = gr.Radio(
215
+ label="Choose Primer Type",
216
+ choices=["Generate from Silence", "Random Maestro MIDI", "Upload My Own MIDI"],
217
+ value="Generate from Silence"
218
+ )
219
+
220
+ with gr.Column(visible=False) as maestro_options:
221
+ maestro_path_input = gr.Textbox(
222
+ label="Path to RAW Maestro MIDI Folder (searches all subfolders)",
223
+ value="./maestro-v2.0.0"
224
+ )
225
+ maestro_start_location_input = gr.Radio(
226
+ label="Primer Start Location",
227
+ choices=["Start of File", "Random Location"],
228
+ value="Random Location",
229
+ info="Selects a random chunk from the file, giving more variety."
230
+ )
231
+
232
+ with gr.Column(visible=False) as upload_options:
233
+ uploaded_midi_input = gr.File(
234
+ label="Upload Your MIDI Primer",
235
+ file_types=[".mid", ".midi"]
236
+ )
237
+ upload_start_location_input = gr.Radio(
238
+ label="Primer Start Location",
239
+ choices=["Start of File", "Random Location"],
240
+ value="Start of File"
241
+ )
242
+
243
+ primer_length_slider = gr.Slider(
244
+ label="Primer Length (Tokens)",
245
+ minimum=64,
246
+ maximum=2000,
247
+ value=512,
248
+ step=32,
249
+ info="How many tokens to use from the primer file. Ignored for 'Silence'."
250
+ )
251
+
252
+ generation_length_slider = gr.Slider(
253
+ label="New Tokens to Generate",
254
+ minimum=128,
255
+ maximum=2048,
256
+ value=1024,
257
+ step=32,
258
+ info="How many new tokens to create after the primer."
259
+ )
260
+
261
+ generate_button = gr.Button("Generate Music", variant="primary")
262
+
263
+ with gr.Row():
264
+ gr.Markdown("### 3. Get Your Music")
265
+ status_output = gr.Textbox(label="Status", interactive=False)
266
+ with gr.Row():
267
+ output_midi_file = gr.File(label="Download Generated MIDI")
268
+ # --- MODIFICATION: Added Audio player ---
269
+ output_wav_file = gr.Audio(label="Listen to Generated WAV", type="filepath")
270
+ # --- END MODIFICATION ---
271
+
272
+
273
+ # --- UI Event Listeners ---
274
+
275
+ def update_ui(primer_type):
276
+ return {
277
+ maestro_options: gr.Column(visible=(primer_type == "Random Maestro MIDI")),
278
+ upload_options: gr.Column(visible=(primer_type == "Upload My Own MIDI")),
279
+ primer_length_slider: gr.Slider(visible=(primer_type != "Generate from Silence"))
280
+ }
281
+
282
+
283
+ primer_type_input.change(
284
+ fn=update_ui,
285
+ inputs=primer_type_input,
286
+ outputs=[maestro_options, upload_options, primer_length_slider]
287
+ )
288
+
289
+ load_button.click(
290
+ fn=load_model,
291
+ inputs=model_path_input,
292
+ outputs=load_status
293
+ )
294
+
295
+ # --- MODIFICATION: Updated outputs list ---
296
+ generate_button.click(
297
+ fn=generate_music,
298
+ inputs=[
299
+ primer_type_input,
300
+ uploaded_midi_input,
301
+ upload_start_location_input,
302
+ maestro_path_input,
303
+ maestro_start_location_input,
304
+ primer_length_slider,
305
+ generation_length_slider
306
+ ],
307
+ outputs=[status_output, output_midi_file, output_wav_file] # <-- Added WAV output
308
+ )
309
+ # --- END MODIFICATION ---
310
+
311
+ if __name__ == "__main__":
312
+ # Check if CUDA is available and set device
313
+ if (not torch.cuda.is_available()):
314
+ print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
315
+ use_cuda(False)
316
+
317
+ print("Launching Gradio UI...")
318
+ app.launch()
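For quick checks outside the browser, the same components can be driven headlessly. A minimal sketch, assuming best_acc_weights.pickle (committed below) sits next to app.py and that the hyperparameters in MODEL_CONFIG above match the checkpoint; the output filename is only illustrative:

import torch
from model.music_transformer import MusicTransformer
from processor import decode_midi
from utilities.constants import TORCH_LABEL_TYPE
from utilities.device import get_device

device = get_device()
model = MusicTransformer(n_layers=6, num_heads=8, d_model=512,
                         dim_feedforward=1024, max_sequence=2048, rpr=True).to(device)
model.load_state_dict(torch.load("best_acc_weights.pickle", map_location=device, weights_only=True))
model.eval()

# Same "silence" primer the UI uses: a single medium-velocity token (356 + 16 = 372)
primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=device)
with torch.no_grad():
    seq = model.generate(primer.unsqueeze(0), 1024, beam=0)

# Drop the single primer token and write the rest as MIDI (illustrative filename)
decode_midi(seq[0][1:].cpu().numpy(), "headless_output.mid")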
best_acc_weights.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bba9e1c91f449753383895c224383dd0ab8402ed003d0d6eb83a4d3f9a3de5df
+ size 59442741
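Note that the entry above is only a Git LFS pointer: a plain clone without git-lfs fetches these three lines rather than the ~59 MB state dict, so run git lfs pull (or download the file from the Hub) before loading it. A hedged sketch of the download route, assuming huggingface_hub is installed; the repo id below is a guess for illustration:

from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="utkucoban/NanoMaestro",      # hypothetical id; substitute the real repo
    filename="best_acc_weights.pickle",
    # repo_type="space",                  # uncomment if the file lives in a Space rather than a model repo
)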
dataset/__init__.py ADDED
File without changes
dataset/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (153 Bytes).
 
dataset/__pycache__/e_piano.cpython-312.pyc ADDED
Binary file (5.88 kB).
 
dataset/e_piano.py ADDED
@@ -0,0 +1,169 @@
1
+ import os
2
+ import pickle
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import Dataset
7
+
8
+ from utilities.constants import *
9
+ from utilities.device import cpu_device
10
+
11
+ SEQUENCE_START = 0
12
+
13
+ # EPianoDataset
14
+ class EPianoDataset(Dataset):
15
+ """
16
+ ----------
17
+ Author: Damon Gwinn
18
+ ----------
19
+ Pytorch Dataset for the Maestro e-piano dataset (https://magenta.tensorflow.org/datasets/maestro).
20
+ Recommended to use with Dataloader (https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
21
+
22
+ Uses all files found in the given root directory of pre-processed (preprocess_midi.py)
23
+ Maestro midi files.
24
+ ----------
25
+ """
26
+
27
+ def __init__(self, root, max_seq=2048, random_seq=True):
28
+ self.root = root
29
+ self.max_seq = max_seq
30
+ self.random_seq = random_seq
31
+
32
+ fs = [os.path.join(root, f) for f in os.listdir(self.root)]
33
+ self.data_files = [f for f in fs if os.path.isfile(f)]
34
+
35
+ # __len__
36
+ def __len__(self):
37
+ """
38
+ ----------
39
+ Author: Damon Gwinn
40
+ ----------
41
+ How many data files exist in the given directory
42
+ ----------
43
+ """
44
+
45
+ return len(self.data_files)
46
+
47
+ # __getitem__
48
+ def __getitem__(self, idx):
49
+ """
50
+ ----------
51
+ Author: Damon Gwinn
52
+ ----------
53
+ Gets the indexed midi batch. Gets random sequence or from start depending on random_seq.
54
+
55
+ Returns the input and the target.
56
+ ----------
57
+ """
58
+
59
+ # All data on cpu to allow for the Dataloader to multithread
60
+ i_stream = open(self.data_files[idx], "rb")
61
+ # return pickle.load(i_stream), None
62
+ raw_mid = torch.tensor(pickle.load(i_stream), dtype=TORCH_LABEL_TYPE, device=cpu_device())
63
+ i_stream.close()
64
+
65
+ x, tgt = process_midi(raw_mid, self.max_seq, self.random_seq)
66
+
67
+ return x, tgt
68
+
69
+ # process_midi
70
+ def process_midi(raw_mid, max_seq, random_seq):
71
+ """
72
+ ----------
73
+ Author: Damon Gwinn
74
+ ----------
75
+ Takes in pre-processed raw midi and returns the input and target. Can use a random sequence or
76
+ go from the start based on random_seq.
77
+ ----------
78
+ """
79
+
80
+ x = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device())
81
+ tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device())
82
+
83
+ raw_len = len(raw_mid)
84
+ full_seq = max_seq + 1 # Performing seq2seq
85
+
86
+ if(raw_len == 0):
87
+ return x, tgt
88
+
89
+ if(raw_len < full_seq):
90
+ x[:raw_len] = raw_mid
91
+ tgt[:raw_len-1] = raw_mid[1:]
92
+ tgt[raw_len] = TOKEN_END
93
+ else:
94
+ # Randomly selecting a range
95
+ if(random_seq):
96
+ end_range = raw_len - full_seq
97
+ start = random.randint(SEQUENCE_START, end_range)
98
+
99
+ # Always taking from the start to as far as we can
100
+ else:
101
+ start = SEQUENCE_START
102
+
103
+ end = start + full_seq
104
+
105
+ data = raw_mid[start:end]
106
+
107
+ x = data[:max_seq]
108
+ tgt = data[1:full_seq]
109
+
110
+
111
+ # print("x:",x)
112
+ # print("tgt:",tgt)
113
+
114
+ return x, tgt
115
+
116
+
117
+ # create_epiano_datasets
118
+ def create_epiano_datasets(dataset_root, max_seq, random_seq=True):
119
+ """
120
+ ----------
121
+ Author: Damon Gwinn
122
+ ----------
123
+ Creates train, evaluation, and test EPianoDataset objects for a pre-processed (preprocess_midi.py)
124
+ root containing train, val, and test folders.
125
+ ----------
126
+ """
127
+
128
+ train_root = os.path.join(dataset_root, "train")
129
+ val_root = os.path.join(dataset_root, "val")
130
+ test_root = os.path.join(dataset_root, "test")
131
+
132
+ train_dataset = EPianoDataset(train_root, max_seq, random_seq)
133
+ val_dataset = EPianoDataset(val_root, max_seq, random_seq)
134
+ test_dataset = EPianoDataset(test_root, max_seq, random_seq)
135
+
136
+ return train_dataset, val_dataset, test_dataset
137
+
138
+ # compute_epiano_accuracy
139
+ def compute_epiano_accuracy(out, tgt):
140
+ """
141
+ ----------
142
+ Author: Damon Gwinn
143
+ ----------
144
+ Computes the average accuracy for the given input and output batches. Accuracy uses softmax
145
+ of the output.
146
+ ----------
147
+ """
148
+
149
+ softmax = nn.Softmax(dim=-1)
150
+ out = torch.argmax(softmax(out), dim=-1)
151
+
152
+ out = out.flatten()
153
+ tgt = tgt.flatten()
154
+
155
+ mask = (tgt != TOKEN_PAD)
156
+
157
+ out = out[mask]
158
+ tgt = tgt[mask]
159
+
160
+ # Empty
161
+ if(len(tgt) == 0):
162
+ return 1.0
163
+
164
+ num_right = (out == tgt)
165
+ num_right = torch.sum(num_right).type(TORCH_FLOAT)
166
+
167
+ acc = num_right / len(tgt)
168
+
169
+ return acc
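For context, a short sketch of how these pieces are normally wired together for training, assuming the preprocessed dataset lives in ./dataset/e_piano with train/val/test subfolders (the default paths used by the argument parser further down):

from torch.utils.data import DataLoader
from dataset.e_piano import create_epiano_datasets

train_ds, val_ds, test_ds = create_epiano_datasets("./dataset/e_piano", max_seq=2048)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=1)

# x and tgt are both (batch, max_seq); tgt is x shifted left by one and padded with TOKEN_PAD
x, tgt = next(iter(train_loader))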
generate.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import random
5
+ import pretty_midi
6
+ import processor
7
+
8
+ from processor import encode_midi, decode_midi
9
+
10
+ from utilities.argument_funcs import parse_generate_args, print_generate_args
11
+ from model.music_transformer import MusicTransformer
12
+ from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, process_midi
13
+ from torch.utils.data import DataLoader
14
+ from torch.optim import Adam
15
+
16
+ from utilities.constants import *
17
+ from utilities.device import get_device, use_cuda
18
+
19
+
20
+ # main
21
+ def main():
22
+ """
23
+ ----------
24
+ Author: Damon Gwinn
25
+ ----------
26
+ Entry point. Generates music from a model specified by command line arguments
27
+ ----------
28
+ """
29
+
30
+ args = parse_generate_args()
31
+ print_generate_args(args)
32
+
33
+ if (args.force_cpu):
34
+ use_cuda(False)
35
+ print("WARNING: Forced CPU usage, expect model to perform slower")
36
+ print("")
37
+
38
+ os.makedirs(args.output_dir, exist_ok=True)
39
+
40
+ # --- MODIFIED LOGIC ---
41
+ # Can be None, an integer index to dataset, or a file path
42
+ if (args.primer_file is None):
43
+ # --- Load dataset ONLY if no primer file is given ---
44
+ print("No primer file provided, loading dataset to pick a random primer...")
45
+ _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
46
+ # --- --- --- --- --- --- --- --- --- --- --- --- --- ---
47
+
48
+ f = str(random.randrange(len(dataset)))
49
+ idx = int(f)
50
+ primer, _ = dataset[idx]
51
+ primer = primer.to(get_device())
52
+ num_primer = primer.shape[0]
53
+ print("Using primer index:", idx, "(", dataset.data_files[idx], ")")
54
+
55
+ else:
56
+ # --- Primer file is provided, NO DATASET NEEDED (unless it's an index) ---
57
+ f = args.primer_file
58
+
59
+ # --- NEW: Check for "silence" ---
60
+ if (f.lower() == "silence"):
61
+ print("Generating from silence...")
62
+ # Create a primer with one token: a medium velocity (64 // 4 = 16)
63
+ # Velocity START_IDX = 356. Token = 356 + 16 = 372
64
+ primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=get_device())
65
+ num_primer = primer.shape[0] # This will be 1
66
+
67
+ # This part handles if the primer is an integer index (e.g., "3")
68
+ elif (f.isdigit()):
69
+ print("Primer file is an index, loading dataset...")
70
+ # --- Load dataset ONLY if primer is an index ---
71
+ _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
72
+ # --- --- --- --- --- --- --- --- --- --- --- ---
73
+ idx = int(f)
74
+ primer, _ = dataset[idx]
75
+ primer = primer.to(get_device())
76
+ num_primer = primer.shape[0]
77
+ print("Using primer index:", idx, "(", dataset.data_files[idx], ")")
78
+
79
+ # This part handles the case where the primer is a MIDI file path (e.g., "my_primer.mid")
80
+ else:
81
+ print("Primer file is a MIDI path. Loading and tokenizing...")
82
+ raw_mid = encode_midi(f)
83
+ if (len(raw_mid) == 0):
84
+ print("Error: No midi messages in primer file:", f)
85
+ return
86
+
87
+ primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
88
+ primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
89
+ num_primer = primer.shape[0] # Get the actual primer length
90
+ print("Using primer file:", f)
91
+
92
+ # --- END MODIFIED LOGIC ---
93
+
94
+ model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
95
+ d_model=args.d_model, dim_feedforward=args.dim_feedforward,
96
+ max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())
97
+
98
+ model.load_state_dict(torch.load(args.model_weights))
99
+
100
+ # --- MODIFICATION: Don't save a primer if we started from silence ---
101
+ # (guard: primer_file is None when a random dataset primer was used above)
+ if (args.primer_file is None) or (args.primer_file.lower() != "silence"):
102
+ f_path = os.path.join(args.output_dir, "primer.mid")
103
+ decode_midi(primer.cpu().numpy(), file_path=f_path)
104
+ # --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
105
+
106
+ # GENERATION
107
+ model.eval()
108
+ with torch.set_grad_enabled(False):
109
+ if (args.beam > 0):
110
+ print("BEAM:", args.beam)
111
+
112
+ beam_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=args.beam)
113
+
114
+ # --- MODIFICATION: Slice the primer ---
115
+ generated_only = beam_seq[0][num_primer:]
116
+ # ------------------------------------
117
+
118
+ f_path = os.path.join(args.output_dir, "beam.mid")
119
+ decode_midi(generated_only.cpu().numpy(), file_path=f_path)
120
+ else:
121
+ print("RAND DIST")
122
+
123
+ rand_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=0)
124
+
125
+ # --- MODIFICATION: Slice the primer ---
126
+ generated_only = rand_seq[0][num_primer:]
127
+ # ------------------------------------
128
+
129
+ f_path = os.path.join(args.output_dir, "rand.mid")
130
+ decode_midi(generated_only.cpu().numpy(), file_path=f_path)
131
+
132
+
133
+ if __name__ == "__main__":
134
+ main()
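Note on usage: the released checkpoint was trained with the RPR variant and the default 6-layer / 8-head / 512-d configuration (see MODEL_CONFIG in app.py above), so --rpr has to be passed or the state dict will not load cleanly. Using only the flags defined in utilities/argument_funcs.py, generating from silence would look roughly like: python generate.py -primer_file silence -model_weights best_acc_weights.pickle --rpr -target_seq_length 1024 -output_dir ./gen, and priming from a MIDI file like: python generate.py -primer_file my_primer.mid -num_prime 256 -model_weights best_acc_weights.pickle --rpr.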
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (151 Bytes).
 
model/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (155 Bytes).
 
model/__pycache__/music_transformer.cpython-312.pyc ADDED
Binary file (8.13 kB).
 
model/__pycache__/music_transformer.cpython-313.pyc ADDED
Binary file (6.56 kB).
 
model/__pycache__/positional_encoding.cpython-312.pyc ADDED
Binary file (2.05 kB).
 
model/__pycache__/positional_encoding.cpython-313.pyc ADDED
Binary file (2.11 kB).
 
model/__pycache__/rpr.cpython-312.pyc ADDED
Binary file (19.3 kB).
 
model/__pycache__/rpr.cpython-313.pyc ADDED
Binary file (11.3 kB).
 
model/music_transformer.py ADDED
@@ -0,0 +1,135 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.modules.normalization import LayerNorm
4
+ import random
5
+
6
+ from utilities.constants import *
7
+ from utilities.device import get_device
8
+
9
+ from .positional_encoding import PositionalEncoding
10
+ from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR
11
+
12
+
13
+ # MusicTransformer
14
+ class MusicTransformer(nn.Module):
15
+ def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
16
+ dropout=0.1, max_sequence=2048, rpr=False):
17
+ super(MusicTransformer, self).__init__()
18
+
19
+ self.dummy = DummyDecoder()
20
+
21
+ self.nlayers = n_layers
22
+ self.nhead = num_heads
23
+ self.d_model = d_model
24
+ self.d_ff = dim_feedforward
25
+ self.dropout = dropout
26
+ self.max_seq = max_sequence
27
+ self.rpr = rpr
28
+
29
+ # Input embedding
30
+ self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)
31
+
32
+ # Positional encoding
33
+ self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)
34
+
35
+ # Base transformer
36
+ if(not self.rpr):
37
+ self.transformer = nn.Transformer(
38
+ d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
39
+ num_decoder_layers=0, dropout=self.dropout,
40
+ dim_feedforward=self.d_ff, custom_decoder=self.dummy
41
+ )
42
+ else:
43
+ encoder_norm = LayerNorm(self.d_model)
44
+ encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
45
+ encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
46
+ self.transformer = nn.Transformer(
47
+ d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
48
+ num_decoder_layers=0, dropout=self.dropout,
49
+ dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
50
+ )
51
+
52
+ # Final output is a softmaxed linear layer
53
+ self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
54
+ self.softmax = nn.Softmax(dim=-1)
55
+
56
+ # forward
57
+ def forward(self, x, mask=True):
58
+ # --- FIX: USE DEVICE OF INPUT TENSOR x ---
59
+ if(mask is True):
60
+ # Generate mask on the same device as input x
61
+ mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(x.device)
62
+ else:
63
+ mask = None
64
+ # -----------------------------------------
65
+
66
+ x = self.embedding(x)
67
+
68
+ # Input shape is (max_seq, batch_size, d_model)
69
+ x = x.permute(1,0,2)
70
+
71
+ x = self.positional_encoding(x)
72
+
73
+ # Since there are no true decoder layers, the tgt is unused
74
+ x_out = self.transformer(src=x, tgt=x, src_mask=mask)
75
+
76
+ # Back to (batch_size, max_seq, d_model)
77
+ x_out = x_out.permute(1,0,2)
78
+
79
+ y = self.Wout(x_out)
80
+ return y
81
+
82
+ # generate
83
+ def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
84
+ assert (not self.training), "Cannot generate while in training mode"
85
+
86
+ print("Generating sequence of max length:", target_seq_length)
87
+
88
+ batch_size = primer.shape[0]
89
+ gen_seq = torch.full((batch_size, target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
90
+
91
+ num_primer = primer.shape[1]
92
+ gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())
93
+
94
+ cur_i = num_primer
95
+ while(cur_i < target_seq_length):
96
+ y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
97
+ token_probs = y[:, cur_i-1, :]
98
+
99
+ if(beam == 0):
100
+ beam_ran = 2.0
101
+ else:
102
+ beam_ran = random.uniform(0,1)
103
+
104
+ if(beam_ran <= beam_chance):
105
+ token_probs = token_probs.flatten()
106
+ top_res, top_i = torch.topk(token_probs, beam)
107
+
108
+ beam_rows = top_i // VOCAB_SIZE
109
+ beam_cols = top_i % VOCAB_SIZE
110
+
111
+ gen_seq = gen_seq[beam_rows, :]
112
+ gen_seq[..., cur_i] = beam_cols
113
+
114
+ else:
115
+ distrib = torch.distributions.categorical.Categorical(probs=token_probs)
116
+ next_token = distrib.sample()
117
+ gen_seq[:, cur_i] = next_token
118
+
119
+ if(next_token == TOKEN_END):
120
+ print("Model called end of sequence at:", cur_i, "/", target_seq_length)
121
+ break
122
+
123
+ cur_i += 1
124
+ if(cur_i % 50 == 0):
125
+ print(cur_i, "/", target_seq_length)
126
+
127
+ return gen_seq[:, :cur_i]
128
+
129
+ # Used as a dummy to nn.Transformer
130
+ class DummyDecoder(nn.Module):
131
+ def __init__(self):
132
+ super(DummyDecoder, self).__init__()
133
+
134
+ def forward(self, tgt, memory, tgt_mask, memory_mask,tgt_key_padding_mask,memory_key_padding_mask, **kwargs):
135
+ return memory
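As a quick shape reference, a sketch of how the module above is called; the model here is randomly initialized (so the output is noise) and shrunk to two layers purely to keep the example light:

import torch
from model.music_transformer import MusicTransformer
from utilities.constants import TORCH_LABEL_TYPE, VOCAB_SIZE

model = MusicTransformer(n_layers=2, num_heads=8, d_model=512,
                         dim_feedforward=1024, max_sequence=2048, rpr=True)
model.eval()

tokens = torch.randint(0, VOCAB_SIZE, (1, 16), dtype=TORCH_LABEL_TYPE)
logits = model(tokens)                                      # (1, 16, VOCAB_SIZE) raw logits
generated = model.generate(tokens, target_seq_length=32)    # (1, <=32) token ids, primer included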
model/positional_encoding.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ # PositionalEncoding
+ # Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+ class PositionalEncoding(nn.Module):
+
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + self.pe[:x.size(0), :]
+         return self.dropout(x)
model/rpr.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.parameter import Parameter
5
+ from torch.nn import Module
6
+ from torch.nn.modules.linear import Linear
7
+ from torch.nn.modules.dropout import Dropout
8
+ from torch.nn.modules.normalization import LayerNorm
9
+ from torch.nn.init import *
10
+
11
+ class TransformerEncoderRPR(Module):
12
+ def __init__(self, encoder_layer, num_layers, norm=None):
13
+ super(TransformerEncoderRPR, self).__init__()
14
+ self.layers = torch.nn.ModuleList([encoder_layer for _ in range(num_layers)]) # Fix for tracing
15
+ self.num_layers = num_layers
16
+ self.norm = norm
17
+
18
+ def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):
19
+ output = src
20
+ for layer in self.layers:
21
+ output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
22
+ if self.norm:
23
+ output = self.norm(output)
24
+ return output
25
+
26
+ class TransformerEncoderLayerRPR(Module):
27
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
28
+ super(TransformerEncoderLayerRPR, self).__init__()
29
+ self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len)
30
+ self.linear1 = Linear(d_model, dim_feedforward)
31
+ self.dropout = Dropout(dropout)
32
+ self.linear2 = Linear(dim_feedforward, d_model)
33
+ self.norm1 = LayerNorm(d_model)
34
+ self.norm2 = LayerNorm(d_model)
35
+ self.dropout1 = Dropout(dropout)
36
+ self.dropout2 = Dropout(dropout)
37
+
38
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
39
+ src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
40
+ src = src + self.dropout1(src2)
41
+ src = self.norm1(src)
42
+ src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
43
+ src = src + self.dropout2(src2)
44
+ src = self.norm2(src)
45
+ return src
46
+
47
+ class MultiheadAttentionRPR(Module):
48
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
49
+ super(MultiheadAttentionRPR, self).__init__()
50
+ self.embed_dim = embed_dim
51
+ self.kdim = kdim if kdim is not None else embed_dim
52
+ self.vdim = vdim if vdim is not None else embed_dim
53
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
54
+ self.num_heads = num_heads
55
+ self.dropout = dropout
56
+ self.head_dim = embed_dim // num_heads
57
+
58
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
59
+ if not self._qkv_same_embed_dim:
60
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
61
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
62
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
63
+
64
+ if bias:
65
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
66
+ else:
67
+ self.register_parameter('in_proj_bias', None)
68
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
69
+ self.add_zero_attn = add_zero_attn
70
+
71
+ if er_len is not None:
72
+ self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32))
73
+ else:
74
+ self.Er = None
75
+ self._reset_parameters()
76
+
77
+ def _reset_parameters(self):
78
+ if self._qkv_same_embed_dim: xavier_uniform_(self.in_proj_weight)
79
+ else:
80
+ xavier_uniform_(self.q_proj_weight)
81
+ xavier_uniform_(self.k_proj_weight)
82
+ xavier_uniform_(self.v_proj_weight)
83
+ if self.in_proj_bias is not None:
84
+ constant_(self.in_proj_bias, 0.)
85
+ constant_(self.out_proj.bias, 0.)
86
+
87
+ def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
88
+ return multi_head_attention_forward_rpr(
89
+ query, key, value, self.embed_dim, self.num_heads, self.head_dim,
90
+ self.in_proj_weight, self.in_proj_bias,
91
+ None, None, self.add_zero_attn,
92
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
93
+ training=self.training,
94
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
95
+ attn_mask=attn_mask, use_separate_proj_weight=not self._qkv_same_embed_dim,
96
+ q_proj_weight=getattr(self, 'q_proj_weight', None),
97
+ k_proj_weight=getattr(self, 'k_proj_weight', None),
98
+ v_proj_weight=getattr(self, 'v_proj_weight', None),
99
+ rpr_mat=self.Er)
100
+
101
+ def multi_head_attention_forward_rpr(query, key, value, embed_dim_to_check, num_heads, head_dim,
102
+ in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn,
103
+ dropout_p, out_proj_weight, out_proj_bias, training=True,
104
+ key_padding_mask=None, need_weights=True, attn_mask=None,
105
+ use_separate_proj_weight=False, q_proj_weight=None,
106
+ k_proj_weight=None, v_proj_weight=None, static_k=None,
107
+ static_v=None, rpr_mat=None):
108
+
109
+ tgt_len, bsz, embed_dim = query.size()
110
+ scaling = float(head_dim) ** -0.5
111
+
112
+ if not use_separate_proj_weight:
113
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
114
+ else:
115
+ q = F.linear(query, q_proj_weight, in_proj_bias[0:embed_dim])
116
+ k = F.linear(key, k_proj_weight, in_proj_bias[embed_dim:(embed_dim * 2)])
117
+ v = F.linear(value, v_proj_weight, in_proj_bias[(embed_dim * 2):])
118
+
119
+ q = q * scaling
120
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
121
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
122
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
123
+
124
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
125
+
126
+ if rpr_mat is not None:
127
+ # Safe Explicit Skew
128
+ len_q = q.shape[1]
129
+ start_idx = rpr_mat.shape[0] - len_q
130
+ rpr_mat_valid = rpr_mat[start_idx:, :]
131
+ qe = torch.einsum("hld,md->hlm", q, rpr_mat_valid)
132
+
133
+ # Indices logic (Flatten -> Gather -> Reshape)
134
+ B, L, _ = qe.shape
135
+ # Mask out upper triangle BEFORE skewing
136
+ mask_tri = torch.triu(torch.ones((L, L), device=qe.device, dtype=torch.bool)).flip(0)
137
+ qe = qe.masked_fill(~mask_tri, 0.0) # Fill with 0 before shift
138
+
139
+ zeros = torch.zeros((B, L, 1), device=qe.device, dtype=qe.dtype)
140
+ qe_pad = torch.cat([zeros, qe], dim=2).view(B, -1)
141
+
142
+ offsets = torch.arange(L * L, device=qe.device, dtype=torch.int64) + L
143
+ offsets = offsets.unsqueeze(0).expand(B, -1)
144
+ srel = torch.gather(qe_pad, 1, offsets).view(B, L, L)
145
+
146
+ attn_output_weights = attn_output_weights + srel
147
+
148
+ # --- MASKING FIX (Boolean Masked Fill) ---
149
+ if attn_mask is not None:
150
+ if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0)
151
+ # ONNX prefers masked_fill with boolean mask over adding -inf
152
+ is_causal_mask = (attn_mask == float('-inf')) | (attn_mask < -1e4)
153
+ attn_output_weights = attn_output_weights.masked_fill(is_causal_mask, float('-inf'))
154
+
155
+ if key_padding_mask is not None:
156
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len)
157
+ attn_output_weights = attn_output_weights.masked_fill(
158
+ key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
159
+ )
160
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, tgt_len)
161
+
162
+ attn_output_weights = F.softmax(attn_output_weights, dim=-1)
163
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
164
+
165
+ attn_output = torch.bmm(attn_output_weights, v)
166
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
167
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
168
+
169
+ if need_weights:
170
+ return attn_output, attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len).sum(dim=1) / num_heads
171
+ return attn_output, None
processor.py ADDED
@@ -0,0 +1,267 @@
1
+ import pretty_midi
2
+
3
+
4
+ RANGE_NOTE_ON = 128
5
+ RANGE_NOTE_OFF = 128
6
+ RANGE_VEL = 32
7
+ RANGE_TIME_SHIFT = 100
8
+
9
+ START_IDX = {
10
+ 'note_on': 0,
11
+ 'note_off': RANGE_NOTE_ON,
12
+ 'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF,
13
+ 'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT
14
+ }
15
+
16
+
17
+ class SustainAdapter:
18
+ def __init__(self, time, type):
19
+ self.start = time
20
+ self.type = type
21
+
22
+
23
+ class SustainDownManager:
24
+ def __init__(self, start, end):
25
+ self.start = start
26
+ self.end = end
27
+ self.managed_notes = []
28
+ self._note_dict = {} # key: pitch, value: note.start
29
+
30
+ def add_managed_note(self, note: pretty_midi.Note):
31
+ self.managed_notes.append(note)
32
+
33
+ def transposition_notes(self):
34
+ for note in reversed(self.managed_notes):
35
+ try:
36
+ note.end = self._note_dict[note.pitch]
37
+ except KeyError:
38
+ note.end = max(self.end, note.end)
39
+ self._note_dict[note.pitch] = note.start
40
+
41
+
42
+ # Divided note by note_on, note_off
43
+ class SplitNote:
44
+ def __init__(self, type, time, value, velocity):
45
+ ## type: note_on, note_off
46
+ self.type = type
47
+ self.time = time
48
+ self.velocity = velocity
49
+ self.value = value
50
+
51
+ def __repr__(self):
52
+ return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\
53
+ .format(self.time, self.type, self.value, self.velocity)
54
+
55
+
56
+ class Event:
57
+ def __init__(self, event_type, value):
58
+ self.type = event_type
59
+ self.value = value
60
+
61
+ def __repr__(self):
62
+ return '<Event type: {}, value: {}>'.format(self.type, self.value)
63
+
64
+ def to_int(self):
65
+ return START_IDX[self.type] + self.value
66
+
67
+ @staticmethod
68
+ def from_int(int_value):
69
+ info = Event._type_check(int_value)
70
+ return Event(info['type'], info['value'])
71
+
72
+ @staticmethod
73
+ def _type_check(int_value):
74
+ range_note_on = range(0, RANGE_NOTE_ON)
75
+ range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON+RANGE_NOTE_OFF)
76
+ range_time_shift = range(RANGE_NOTE_ON+RANGE_NOTE_OFF,RANGE_NOTE_ON+RANGE_NOTE_OFF+RANGE_TIME_SHIFT)
77
+
78
+ valid_value = int_value
79
+
80
+ if int_value in range_note_on:
81
+ return {'type': 'note_on', 'value': valid_value}
82
+ elif int_value in range_note_off:
83
+ valid_value -= RANGE_NOTE_ON
84
+ return {'type': 'note_off', 'value': valid_value}
85
+ elif int_value in range_time_shift:
86
+ valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
87
+ return {'type': 'time_shift', 'value': valid_value}
88
+ else:
89
+ valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
90
+ return {'type': 'velocity', 'value': valid_value}
91
+
92
+
93
+ def _divide_note(notes):
94
+ result_array = []
95
+ notes.sort(key=lambda x: x.start)
96
+
97
+ for note in notes:
98
+ on = SplitNote('note_on', note.start, note.pitch, note.velocity)
99
+ off = SplitNote('note_off', note.end, note.pitch, None)
100
+ result_array += [on, off]
101
+ return result_array
102
+
103
+
104
+ def _merge_note(snote_sequence):
105
+ note_on_dict = {}
106
+ result_array = []
107
+
108
+ for snote in snote_sequence:
109
+ # print(note_on_dict)
110
+ if snote.type == 'note_on':
111
+ note_on_dict[snote.value] = snote
112
+ elif snote.type == 'note_off':
113
+ try:
114
+ on = note_on_dict[snote.value]
115
+ off = snote
116
+ if off.time - on.time == 0:
117
+ continue
118
+ result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time)
119
+ result_array.append(result)
120
+ except:
121
+ print('info removed pitch: {}'.format(snote.value))
122
+ return result_array
123
+
124
+
125
+ def _snote2events(snote: SplitNote, prev_vel: int):
126
+ result = []
127
+ if snote.velocity is not None:
128
+ modified_velocity = snote.velocity // 4
129
+ if prev_vel != modified_velocity:
130
+ result.append(Event(event_type='velocity', value=modified_velocity))
131
+ result.append(Event(event_type=snote.type, value=snote.value))
132
+ return result
133
+
134
+
135
+ def _event_seq2snote_seq(event_sequence):
136
+ timeline = 0
137
+ velocity = 0
138
+ snote_seq = []
139
+
140
+ for event in event_sequence:
141
+ if event.type == 'time_shift':
142
+ timeline += ((event.value+1) / 100)
143
+ if event.type == 'velocity':
144
+ velocity = event.value * 4
145
+ else:
146
+ snote = SplitNote(event.type, timeline, event.value, velocity)
147
+ snote_seq.append(snote)
148
+ return snote_seq
149
+
150
+
151
+ def _make_time_sift_events(prev_time, post_time):
152
+ time_interval = int(round((post_time - prev_time) * 100))
153
+ results = []
154
+ while time_interval >= RANGE_TIME_SHIFT:
155
+ results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT-1))
156
+ time_interval -= RANGE_TIME_SHIFT
157
+ if time_interval == 0:
158
+ return results
159
+ else:
160
+ return results + [Event(event_type='time_shift', value=time_interval-1)]
161
+
162
+
163
+ def _control_preprocess(ctrl_changes):
164
+ sustains = []
165
+
166
+ manager = None
167
+ for ctrl in ctrl_changes:
168
+ if ctrl.value >= 64 and manager is None:
169
+ # sustain down
170
+ manager = SustainDownManager(start=ctrl.time, end=None)
171
+ elif ctrl.value < 64 and manager is not None:
172
+ # sustain up
173
+ manager.end = ctrl.time
174
+ sustains.append(manager)
175
+ manager = None
176
+ elif ctrl.value < 64 and len(sustains) > 0:
177
+ sustains[-1].end = ctrl.time
178
+ return sustains
179
+
180
+
181
+ def _note_preprocess(susteins, notes):
182
+ note_stream = []
183
+
184
+ if susteins: # if the midi file has sustain controls
185
+ for sustain in susteins:
186
+ for note_idx, note in enumerate(notes):
187
+ if note.start < sustain.start:
188
+ note_stream.append(note)
189
+ elif note.start > sustain.end:
190
+ notes = notes[note_idx:]
191
+ sustain.transposition_notes()
192
+ break
193
+ else:
194
+ sustain.add_managed_note(note)
195
+
196
+ for sustain in susteins:
197
+ note_stream += sustain.managed_notes
198
+
199
+ else: # else, just push everything into note stream
200
+ for note_idx, note in enumerate(notes):
201
+ note_stream.append(note)
202
+
203
+ note_stream.sort(key= lambda x: x.start)
204
+ return note_stream
205
+
206
+
207
+ def encode_midi(file_path):
208
+ events = []
209
+ notes = []
210
+ mid = pretty_midi.PrettyMIDI(midi_file=file_path)
211
+
212
+ for inst in mid.instruments:
213
+ inst_notes = inst.notes
214
+ # ctrl.number is the number of the sustain control. If you want to know about the control change numbers,
215
+ # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
216
+ ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
217
+ notes += _note_preprocess(ctrls, inst_notes)
218
+
219
+ dnotes = _divide_note(notes)
220
+
221
+ # print(dnotes)
222
+ dnotes.sort(key=lambda x: x.time)
223
+ # print('sorted:')
224
+ # print(dnotes)
225
+ cur_time = 0
226
+ cur_vel = 0
227
+ for snote in dnotes:
228
+ events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
229
+ events += _snote2events(snote=snote, prev_vel=cur_vel)
230
+ # events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
231
+
232
+ cur_time = snote.time
233
+ cur_vel = snote.velocity
234
+
235
+ return [e.to_int() for e in events]
236
+
237
+
238
+ def decode_midi(idx_array, file_path=None):
239
+ event_sequence = [Event.from_int(idx) for idx in idx_array]
240
+ # print(event_sequence)
241
+ snote_seq = _event_seq2snote_seq(event_sequence)
242
+ note_seq = _merge_note(snote_seq)
243
+ note_seq.sort(key=lambda x:x.start)
244
+
245
+ mid = pretty_midi.PrettyMIDI()
246
+ # if want to change instument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
247
+ instument = pretty_midi.Instrument(0, False, "Composed by Super Piano Music Transformer AI")
248
+ instument.notes = note_seq
249
+
250
+ mid.instruments.append(instument)
251
+ if file_path is not None:
252
+ mid.write(file_path)
253
+ return mid
254
+
255
+
256
+ if __name__ == '__main__':
257
+ encoded = encode_midi('bin/ADIG04.mid')
258
+ print(encoded)
259
+ decided = decode_midi(encoded,file_path='bin/test.mid')
260
+
261
+ ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid')
262
+ print(ins)
263
+ print(ins.instruments[0])
264
+ for i in ins.instruments:
265
+ print(i.control_changes)
266
+ print(i.notes)
267
+
utilities/__init__.py ADDED
File without changes
utilities/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (155 Bytes).
 
utilities/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (159 Bytes).
 
utilities/__pycache__/argument_funcs.cpython-312.pyc ADDED
Binary file (13.4 kB).
 
utilities/__pycache__/constants.cpython-312.pyc ADDED
Binary file (840 Bytes).
 
utilities/__pycache__/constants.cpython-313.pyc ADDED
Binary file (844 Bytes).
 
utilities/__pycache__/device.cpython-312.pyc ADDED
Binary file (1.74 kB).
 
utilities/__pycache__/device.cpython-313.pyc ADDED
Binary file (1.66 kB).
 
utilities/argument_funcs.py ADDED
@@ -0,0 +1,228 @@
1
+ import argparse
2
+
3
+ from .constants import SEPERATOR
4
+
5
+ # parse_train_args
6
+ def parse_train_args():
7
+ """
8
+ ----------
9
+ Author: Damon Gwinn
10
+ ----------
11
+ Argparse arguments for training a model
12
+ ----------
13
+ """
14
+
15
+ parser = argparse.ArgumentParser()
16
+
17
+ parser.add_argument("-input_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
18
+ parser.add_argument("-output_dir", type=str, default="./saved_models", help="Folder to save model weights. Saves one every epoch")
19
+ parser.add_argument("-weight_modulus", type=int, default=1, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)")
20
+ parser.add_argument("-print_modulus", type=int, default=1, help="How often to print train results for a batch (batch loss, learn rate, etc.)")
21
+
22
+ parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
23
+ parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
24
+ parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting")
25
+
26
+ parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on")
27
+ parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at")
28
+
29
+ parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. Leave as None for a custom scheduler.")
30
+ parser.add_argument("-ce_smoothing", type=float, default=None, help="Smoothing parameter for smoothed cross entropy loss (defaults to no smoothing)")
31
+ parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
32
+ parser.add_argument("-epochs", type=int, default=100, help="Number of epochs to use")
33
+
34
+ parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
35
+ parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
36
+ parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
37
+ parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
38
+ parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
39
+
40
+ parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
41
+
42
+ parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate")
43
+
44
+ return parser.parse_args()
45
+
46
+ # print_train_args
47
+ def print_train_args(args):
48
+ """
49
+ ----------
50
+ Author: Damon Gwinn
51
+ ----------
52
+ Prints training arguments
53
+ ----------
54
+ """
55
+
56
+ print(SEPERATOR)
57
+ print("input_dir:", args.input_dir)
58
+ print("output_dir:", args.output_dir)
59
+ print("weight_modulus:", args.weight_modulus)
60
+ print("print_modulus:", args.print_modulus)
61
+ print("")
62
+ print("n_workers:", args.n_workers)
63
+ print("force_cpu:", args.force_cpu)
64
+ print("tensorboard:", not args.no_tensorboard)
65
+ print("")
66
+ print("continue_weights:", args.continue_weights)
67
+ print("continue_epoch:", args.continue_epoch)
68
+ print("")
69
+ print("lr:", args.lr)
70
+ print("ce_smoothing:", args.ce_smoothing)
71
+ print("batch_size:", args.batch_size)
72
+ print("epochs:", args.epochs)
73
+ print("")
74
+ print("rpr:", args.rpr)
75
+ print("max_sequence:", args.max_sequence)
76
+ print("n_layers:", args.n_layers)
77
+ print("num_heads:", args.num_heads)
78
+ print("d_model:", args.d_model)
79
+ print("")
80
+ print("dim_feedforward:", args.dim_feedforward)
81
+ print("dropout:", args.dropout)
82
+ print(SEPERATOR)
83
+ print("")
84
+
85
+ # parse_eval_args
86
+ def parse_eval_args():
87
+ """
88
+ ----------
89
+ Author: Damon Gwinn
90
+ ----------
91
+ Argparse arguments for evaluating a model
92
+ ----------
93
+ """
94
+
95
+ parser = argparse.ArgumentParser()
96
+
97
+ parser.add_argument("-dataset_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
98
+ parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
99
+ parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
100
+ parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
101
+
102
+ parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
103
+
104
+ parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
105
+ parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider in the model")
106
+ parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
107
+ parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
108
+ parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
109
+
110
+ parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
111
+
112
+ return parser.parse_args()
113
+
114
+ # print_eval_args
115
+ def print_eval_args(args):
116
+ """
117
+ ----------
118
+ Author: Damon Gwinn
119
+ ----------
120
+ Prints evaluation arguments
121
+ ----------
122
+ """
123
+
124
+ print(SEPERATOR)
125
+ print("dataset_dir:", args.dataset_dir)
126
+ print("model_weights:", args.model_weights)
127
+ print("n_workers:", args.n_workers)
128
+ print("force_cpu:", args.force_cpu)
129
+ print("")
130
+ print("batch_size:", args.batch_size)
131
+ print("")
132
+ print("rpr:", args.rpr)
133
+ print("max_sequence:", args.max_sequence)
134
+ print("n_layers:", args.n_layers)
135
+ print("num_heads:", args.num_heads)
136
+ print("d_model:", args.d_model)
137
+ print("")
138
+ print("dim_feedforward:", args.dim_feedforward)
139
+ print(SEPERATOR)
140
+ print("")
141
+
142
+ # parse_generate_args
143
+ def parse_generate_args():
144
+ """
145
+ ----------
146
+ Author: Damon Gwinn
147
+ ----------
148
+ Argparse arguments for generation
149
+ ----------
150
+ """
151
+
152
+ parser = argparse.ArgumentParser()
153
+
154
+ parser.add_argument("-midi_root", type=str, default="./dataset/e_piano/", help="Root folder of the preprocessed e-piano dataset (used when -primer_file is an index or omitted)")
155
+ parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to")
156
+ parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. Default is to select a random index.")
157
+ parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
158
+
159
+ parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be")
160
+ parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with")
161
+ parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
162
+ parser.add_argument("-beam", type=int, default=0, help="Beam search k. 0 for random probability sample and 1 for greedy")
163
+
164
+ parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
165
+ parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
166
+ parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
167
+ parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
168
+ parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
169
+
170
+ parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
171
+
172
+ return parser.parse_args()
173
+
174
+ # print_generate_args
175
+ def print_generate_args(args):
176
+ """
177
+ ----------
178
+ Author: Damon Gwinn
179
+ ----------
180
+ Prints generation arguments
181
+ ----------
182
+ """
183
+
184
+ print(SEPERATOR)
185
+ print("midi_root:", args.midi_root)
186
+ print("output_dir:", args.output_dir)
187
+ print("primer_file:", args.primer_file)
188
+ print("force_cpu:", args.force_cpu)
189
+ print("")
190
+ print("target_seq_length:", args.target_seq_length)
191
+ print("num_prime:", args.num_prime)
192
+ print("model_weights:", args.model_weights)
193
+ print("beam:", args.beam)
194
+ print("")
195
+ print("rpr:", args.rpr)
196
+ print("max_sequence:", args.max_sequence)
197
+ print("n_layers:", args.n_layers)
198
+ print("num_heads:", args.num_heads)
199
+ print("d_model:", args.d_model)
200
+ print("")
201
+ print("dim_feedforward:", args.dim_feedforward)
202
+ print(SEPERATOR)
203
+ print("")
204
+
205
+ # write_model_params
206
+ def write_model_params(args, output_file):
207
+ """
208
+ ----------
209
+ Author: Damon Gwinn
210
+ ----------
211
+ Writes given training parameters to text file
212
+ ----------
213
+ """
214
+
215
+ o_stream = open(output_file, "w")
216
+
217
+ o_stream.write("rpr: " + str(args.rpr) + "\n")
218
+ o_stream.write("lr: " + str(args.lr) + "\n")
219
+ o_stream.write("ce_smoothing: " + str(args.ce_smoothing) + "\n")
220
+ o_stream.write("batch_size: " + str(args.batch_size) + "\n")
221
+ o_stream.write("max_sequence: " + str(args.max_sequence) + "\n")
222
+ o_stream.write("n_layers: " + str(args.n_layers) + "\n")
223
+ o_stream.write("num_heads: " + str(args.num_heads) + "\n")
224
+ o_stream.write("d_model: " + str(args.d_model) + "\n")
225
+ o_stream.write("dim_feedforward: " + str(args.dim_feedforward) + "\n")
226
+ o_stream.write("dropout: " + str(args.dropout) + "\n")
227
+
228
+ o_stream.close()
utilities/constants.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+
+ from processor import RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT
+
+ SEPERATOR = "========================="
+
+ # Taken from the paper
+ ADAM_BETA_1 = 0.9
+ ADAM_BETA_2 = 0.98
+ ADAM_EPSILON = 10e-9
+
+ LR_DEFAULT_START = 1.0
+ SCHEDULER_WARMUP_STEPS = 4000
+ # LABEL_SMOOTHING_E = 0.1
+
+ # DROPOUT_P = 0.1
+
+ TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
+ TOKEN_PAD = TOKEN_END + 1
+
+ VOCAB_SIZE = TOKEN_PAD + 1
+
+ TORCH_FLOAT = torch.float32
+ TORCH_INT = torch.int32
+
+ TORCH_LABEL_TYPE = torch.long
+
+ PREPEND_ZEROS_WIDTH = 4
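For reference, with the ranges imported from processor.py these constants resolve to concrete numbers, which is also where the hard-coded primer token 372 in app.py and generate.py comes from:

# TOKEN_END  = 128 + 128 + 32 + 100 = 388   (end-of-sequence token)
# TOKEN_PAD  = 389                          (padding token used by process_midi)
# VOCAB_SIZE = 390                          (output dimension of the model's final Linear layer)
# "Silence" primer token: START_IDX['velocity'] + (64 // 4) = 356 + 16 = 372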
utilities/device.py ADDED
@@ -0,0 +1,67 @@
+ # For all things related to devices
+ #### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS ####
+
+ import torch
+
+ TORCH_CPU_DEVICE = torch.device("cpu")
+
+ if(torch.cuda.device_count() > 0):
+     TORCH_CUDA_DEVICE = torch.device("cuda")
+ else:
+     print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
+     print("")
+     TORCH_CUDA_DEVICE = None
+
+ USE_CUDA = True
+
+ # use_cuda
+ def use_cuda(cuda_bool):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Sets whether to use CUDA (if available), or use the CPU (not recommended)
+     ----------
+     """
+
+     global USE_CUDA
+     USE_CUDA = cuda_bool
+
+ # get_device
+ def get_device():
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Grabs the default device. Default device is CUDA if available and use_cuda is not False, CPU otherwise.
+     ----------
+     """
+
+     if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)):
+         return TORCH_CPU_DEVICE
+     else:
+         return TORCH_CUDA_DEVICE
+
+ # cuda_device
+ def cuda_device():
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Grabs the cuda device (may be None if CUDA is not available)
+     ----------
+     """
+
+     return TORCH_CUDA_DEVICE
+
+ # cpu_device
+ def cpu_device():
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Grabs the cpu device
+     ----------
+     """
+
+     return TORCH_CPU_DEVICE