adricl committed on
Commit c20cfc8 · 1 Parent(s): 94d5306

Added new training data

.ipynb_checkpoints/HuggingFace_Mistral_Transformer_Single_Instrument-checkpoint.ipynb ADDED
@@ -0,0 +1,803 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {
7
+ "id": "SiTIpPjArIyr"
8
+ },
9
+ "source": [
10
+ "# Using Midi traning data and MidiTok Remi to generate music with Mistral model \n",
11
+ "# split music into Single Instrument and split into 1024\n"
12
+ ]
13
+ },
14
+ {
15
+ "attachments": {},
16
+ "cell_type": "markdown",
17
+ "metadata": {
18
+ "id": "gOd93yV0sGd2"
19
+ },
20
+ "source": [
21
+ "## Setup Environment"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "To compile Symusic \n",
31
+ "\n",
32
+ "Get g++11 or higher\n",
33
+ "\n",
34
+ "git clone --recursive https://github.com/Yikai-Liao/symusic\n",
35
+ "CXX=/usr/bin/g++-11 pip install ./symusic\n"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "%pip install torch==2.6.0\n",
45
+ "%pip install evaluate transformers[torch]==4.55.4 tqdm miditok accelerate tensorboardX scikit-learn\n"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 1,
51
+ "metadata": {
52
+ "cellView": "form",
53
+ "id": "fX12Yquyuihc"
54
+ },
55
+ "outputs": [
56
+ {
57
+ "name": "stderr",
58
+ "output_type": "stream",
59
+ "text": [
60
+ "2025-09-12 09:17:37.410013: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
61
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
62
+ "2025-09-12 09:17:38.509451: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
63
+ ]
64
+ }
65
+ ],
66
+ "source": [
67
+ "\n",
68
+ "\n",
69
+ "from copy import deepcopy\n",
70
+ "from pathlib import Path\n",
71
+ "from random import shuffle, sample\n",
72
+ "\n",
73
+ "from evaluate import load as load_metric\n",
74
+ "from miditok import REMI, TokenizerConfig, TokTrainingIterator\n",
75
+ "from miditok.pytorch_data import DatasetMIDI, DataCollator\n",
76
+ "from miditok.utils import split_files_for_training\n",
77
+ "\n",
78
+ "from miditok.data_augmentation import augment_dataset\n",
79
+ "from torch import Tensor, argmax, torch\n",
80
+ "from torch.utils.data import DataLoader\n",
81
+ "from torch.cuda import is_available as cuda_available, is_bf16_supported\n",
82
+ "from torch.backends.mps import is_available as mps_available\n",
83
+ "from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig\n",
84
+ "from transformers.trainer_utils import set_seed\n",
85
+ "from tqdm import tqdm"
86
+ ]
87
+ },
88
+ {
89
+ "attachments": {},
90
+ "cell_type": "markdown",
91
+ "metadata": {},
92
+ "source": [
93
+ "## Setup Tokenizer"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 2,
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "# Seed\n",
103
+ "set_seed(777)\n",
104
+ "\n",
105
+ "# Our tokenizer's configuration\n",
106
+ "BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}\n",
107
+ "TOKENIZER_PARAMS = {\n",
108
+ " \"pitch_range\": (21, 108),\n",
109
+ " \"beat_res\": BEAT_RES,\n",
110
+ " \"num_velocities\": 32,\n",
111
+ " \"special_tokens\": [\"PAD\", \"BOS\", \"EOS\"],\n",
112
+ " \"use_chords\": True,\n",
113
+ " \"use_rests\": True,\n",
114
+ " \"use_tempos\": True,\n",
115
+ " \"use_time_signatures\": True,\n",
116
+ " \"use_programs\": False, # We want single track \n",
117
+ " \"one_token_stream_for_programs\": False, # We want single track\n",
118
+ " \"programs\": list(range(0, 128)), #-1 drums, skip drums\n",
119
+ " \"num_tempos\": 32,\n",
120
+ " \"tempo_range\": (40, 250), # (min_tempo, max_tempo)\n",
121
+ "}\n",
122
+ "config = TokenizerConfig(**TOKENIZER_PARAMS)\n",
123
+ "\n",
124
+ "# Creates the tokenizer REMI PLUS\n",
125
+ "tokenizer = REMI(config)"
126
+ ]
127
+ },
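A minimal sketch of what this tokenizer does, assuming a local `example.mid` file (the path is hypothetical): calling the tokenizer on a file encodes it, and calling it on tokens decodes back to a symusic `Score`.

```python
from pathlib import Path
from miditok import REMI, TokenizerConfig

# Same style of config as above, trimmed to the essentials
config = TokenizerConfig(pitch_range=(21, 108), num_velocities=32,
                         special_tokens=["PAD", "BOS", "EOS"])
tokenizer = REMI(config)
print(len(tokenizer))                    # base vocabulary size, before BPE training

tokens = tokenizer(Path("example.mid"))  # encode: MIDI file -> token sequence(s)
score = tokenizer(tokens)                # decode: tokens -> symusic Score
score.dump_midi("roundtrip.mid")
```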
128
+ {
129
+ "cell_type": "markdown",
130
+ "metadata": {},
131
+ "source": [
132
+ "# Load Midi filed and train the the tokenizer on the midi files"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 3,
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "root_data_dir = Path('/home/wombat/Documents/projects/music/midiTok/data/')\n",
142
+ "root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')\n",
143
+ "\n",
144
+ "tokenizer_name = \"HuggingFace_Mistral_Transformer_Single_Instrument_v4_single_track.json\"\n",
145
+ "dataset_dir = root_save / \"data\"\n",
146
+ "dataset_dir.mkdir(parents=True, exist_ok=True)"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "\n",
156
+ "# Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 30k tokens\n",
157
+ "#data_dirs = [\"adl-piano-midi\", \"maestro-v3.0.0\", \"musicnet_midis\" ] # for single \n",
158
+ "data_dirs = [\"MIDIs\"]\n",
159
+ "midi_paths = []\n",
160
+ "for data_dir in data_dirs:\n",
161
+ " path = Path(root_data_dir / 'Traning Data' / data_dir)\n",
162
+ " midi_paths.extend(list(path.resolve().glob(\"**/*.mid\")) + list(path.resolve().glob(\"**/*.midi\")))\n",
163
+ "\n",
164
+ "print(f\"Found {len(midi_paths)} MIDI files\")\n",
165
+ "\n",
166
+ "shuffle(midi_paths)\n",
167
+ "\n",
168
+ "# We need a subset of files otherwise training tokenizer takes too long\n",
169
+ "percentage_to_select = 0.15\n",
170
+ "num_files_to_select = int(len(midi_paths) * percentage_to_select)\n",
171
+ "\n",
172
+ "subset_midi_paths = sample(midi_paths, num_files_to_select)\n",
173
+ "print(f\"Found {len(subset_midi_paths)} MIDI files\")"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "#Note the size of the dataset is quite large, so it requires a huge amount of memory to train the tokenizer for 61749 files it took 64gb of memory\n",
183
+ "tokenizer.train(\n",
184
+ " vocab_size=24000,\n",
185
+ " files_paths=subset_midi_paths,\n",
186
+ ")\n",
187
+ "tokenizer.save(root_save / tokenizer_name)\n",
188
+ "\n"
189
+ ]
190
+ },
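A quick, hedged sanity check after training (this assumes miditok's `is_trained` flag, which the library exposes on trained tokenizers): the vocabulary should have grown from the base token set toward the requested 24000 entries.

```python
# Vocabulary should now be close to the vocab_size requested above
print(f"Trained: {tokenizer.is_trained}, vocab size: {len(tokenizer)}")
```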
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 4,
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "tokenizer = REMI(params=Path(root_save / tokenizer_name))\n"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "markdown",
202
+ "metadata": {},
203
+ "source": [
204
+ "## Prepare MIDIs for training\n",
205
+ "\n",
206
+ "Here we split the files in three subsets: train, validation and test.\n",
207
+ "Then data augmentation is performed on each subset independently, and the MIDIs are split into smaller chunks that make approximately the desired token sequence length for training."
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 5,
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "sequence_length = 1024 # The maximum sequence length for data samples.\n",
217
+ "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"pre_tokenize\": True, \"pre_tokenize_thread_count\": 7}"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "metadata": {},
223
+ "source": [
224
+ "# Test splitting files for training and testing purposes"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "from pathlib import Path\n",
234
+ "# Split will need to add the BPM to the files its split\n",
235
+ "# \n",
236
+ "file_paths_test = [\n",
237
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Right Here, Right Now.mid'),\n",
238
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Praise You.mid'),\n",
239
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Iris.mid'),\n",
240
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Slide.mid'),\n",
241
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/James Brown/Sex Machine (Get Up I Feel Like Being A).mid'),\n",
242
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.1.mid'),\n",
243
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.mid')\n",
244
+ "]\n",
245
+ "\n",
246
+ "split_files_for_training(\n",
247
+ " files_paths=file_paths_test,\n",
248
+ " tokenizer=tokenizer,\n",
249
+ " save_dir=Path('/home/wombat/Documents/projects/music/midiTok/data/HuggingFace_Mistral_Transformer_Single_Instrument/test'),\n",
250
+ " max_seq_len=sequence_length,\n",
251
+ " num_overlap_bars=2,\n",
252
+ " skip_drums=True\n",
253
+ ")"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "# Split MIDI paths in train/valid/test sets\n",
263
+ "total_num_files = len(midi_paths)\n",
264
+ "\n",
265
+ "num_files_valid = round(total_num_files * 0.15)\n",
266
+ "num_files_test = round(total_num_files * 0.15)\n",
267
+ "shuffle(midi_paths)\n",
268
+ "midi_paths_valid = midi_paths[:num_files_valid]\n",
269
+ "midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]\n",
270
+ "midi_paths_train = midi_paths[num_files_valid + num_files_test:]\n",
271
+ "\n",
272
+ "\n",
273
+ "\n",
274
+ "# Chunk MIDIs and perform data augmentation on each subset independently\n",
275
+ "for files_paths, subset_name in (\n",
276
+ " (midi_paths_train, \"train\"), (midi_paths_valid, \"valid\"), (midi_paths_test, \"test\")\n",
277
+ "):\n",
278
+ "\n",
279
+ " # Split the MIDIs into chunks of sizes approximately about 1024 tokens\n",
280
+ " subset_chunks_dir = root_save / f\"Maestro_{subset_name}\"\n",
281
+ " print(subset_chunks_dir)\n",
282
+ " split_files_for_training(\n",
283
+ " files_paths=files_paths,\n",
284
+ " tokenizer=tokenizer,\n",
285
+ " save_dir=subset_chunks_dir,\n",
286
+ " max_seq_len=sequence_length,\n",
287
+ " num_overlap_bars=2,\n",
288
+ " skip_drums=True\n",
289
+ " )\n",
290
+ "\n",
291
+ " if subset_name == 'train':\n",
292
+ " print(\"Augmentation\")\n",
293
+ " # Perform data augmentation\n",
294
+ " augment_dataset(\n",
295
+ " subset_chunks_dir,\n",
296
+ " pitch_offsets=[-12, 12],\n",
297
+ " velocity_offsets=[-4, 4],\n",
298
+ " duration_offsets=[-0.5, 0.5],\n",
299
+ " )\n"
300
+ ]
301
+ },
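The 15%/15%/70% split above in miniature, using a stand-in list of 1000 shuffled paths:

```python
paths = list(range(1000))                  # stand-in for shuffled MIDI paths
n_valid = round(len(paths) * 0.15)         # 150
n_test = round(len(paths) * 0.15)          # 150
valid = paths[:n_valid]
test = paths[n_valid:n_valid + n_test]
train = paths[n_valid + n_test:]
print(len(train), len(valid), len(test))   # 700 150 150
```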
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": 6,
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "#Since the datasets are too large after splitting we only want 50% of the split data to train against\n",
309
+ "sample_subset_per = .5"
310
+ ]
311
+ },
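Slice indices must be integers, which is exactly the `TypeError` the next cell originally hit; casting with `int()` (or using `random.sample`) fixes it. A self-contained sketch:

```python
from random import sample

paths = list(range(100))                 # stand-in for a list of MIDI paths
sample_subset_per = 0.5                  # same value as above
k = int(len(paths) * sample_subset_per)  # int() keeps the slice index valid
print(len(paths[:k]), len(sample(paths, k)))  # 50 50
```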
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 8,
315
+ "metadata": {},
316
+ "outputs": [
317
+ {
318
+ "ename": "TypeError",
319
+ "evalue": "slice indices must be integers or None or have an __index__ method",
320
+ "output_type": "error",
321
+ "traceback": [
322
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
323
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
324
+ "Cell \u001b[0;32mIn[8], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Create Dataset and Collator for training\u001b[39;00m\n\u001b[1;32m 2\u001b[0m midi_paths_train \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(root_save\u001b[38;5;241m.\u001b[39mjoinpath(Path(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaestro_train\u001b[39m\u001b[38;5;124m\"\u001b[39m))\u001b[38;5;241m.\u001b[39mglob(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m**/*.mid\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(root_save\u001b[38;5;241m.\u001b[39mjoinpath(Path(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaestro_train\u001b[39m\u001b[38;5;124m\"\u001b[39m))\u001b[38;5;241m.\u001b[39mglob(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m**/*.midi\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m----> 3\u001b[0m midi_paths_train \u001b[38;5;241m=\u001b[39m \u001b[43mmidi_paths_train\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmidi_paths_train\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msample_subset_per\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(midi_paths_train))\n\u001b[1;32m 5\u001b[0m dataset_train \u001b[38;5;241m=\u001b[39m DatasetMIDI(midi_paths_train, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs_dataset)\n",
325
+ "\u001b[0;31mTypeError\u001b[0m: slice indices must be integers or None or have an __index__ method"
326
+ ]
327
+ }
328
+ ],
329
+ "source": [
330
+ "# Create Dataset and Collator for training\n",
331
+ "midi_paths_train = list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.midi\"))\n",
332
+ "sample_count = (len(midi_paths_train)*sample_subset_per)\n",
333
+ "midi_paths_train = midi_paths_train[:]\n",
334
+ "print(len(midi_paths_train))\n",
335
+ "dataset_train = DatasetMIDI(midi_paths_train, **kwargs_dataset)\n",
336
+ "torch.save(dataset_train, Path(dataset_dir / \"dataset_train.pt\"))"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "metadata": {},
343
+ "outputs": [],
344
+ "source": [
345
+ "midi_paths_valid = list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.midi\")) \n",
346
+ "midi_paths_valid = midi_paths_valid[:(len(midi_paths_valid)*sample_subset_per]\n",
347
+ "print(len(midi_paths_valid))\n",
348
+ "dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)\n",
349
+ "torch.save(dataset_valid, Path(dataset_dir / \"dataset_valid.pt\"))"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "midi_paths_test = list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.midi\"))\n",
359
+ "midi_paths_test = midi_paths_test[:(len(midi_paths_test)*sample_subset_per]\n",
360
+ "print(len(midi_paths_test))\n",
361
+ "dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)\n",
362
+ "torch.save(dataset_test, Path(dataset_dir / \"dataset_test.pt\"))\n"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "print (len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))\n"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "markdown",
376
+ "metadata": {},
377
+ "source": [
378
+ "# Save and Load datasets"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "metadata": {},
385
+ "outputs": [],
386
+ "source": [
387
+ "\n",
388
+ "dataset_train = torch.load(Path(dataset_dir / \"dataset_train.pt\"), weights_only=False)\n",
389
+ "dataset_valid = torch.load(Path(dataset_dir / \"dataset_valid.pt\"), weights_only=False)\n",
390
+ "dataset_test = torch.load(Path(dataset_dir / \"dataset_test.pt\"), weights_only=False)\n",
391
+ "\n"
392
+ ]
393
+ },
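A minimal sketch of the save/load round trip used above. `DatasetMIDI` instances are full Python objects rather than plain tensors, which is why `weights_only=False` is needed; note that this unpickles arbitrary objects, so only load files you created yourself.

```python
import torch

torch.save({"example": [1, 2, 3]}, "obj.pt")    # pickles any Python object
obj = torch.load("obj.pt", weights_only=False)  # unpickles it; trusted files only
print(obj["example"])
```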
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": null,
397
+ "metadata": {},
398
+ "outputs": [],
399
+ "source": [
400
+ "import pickle\n",
401
+ "\n",
402
+ "test_file = open(Path(dataset_dir / \"dataset_test.pickle\"), 'ab')\n",
403
+ "pickle.dump(dataset_test, test_file)\n",
404
+ "test_file.close()\n",
405
+ "\n",
406
+ "print(dataset_test[0])\n",
407
+ "\n",
408
+ "test_file = open(Path(dataset_dir / \"dataset_test.pickle\"), 'rb')\n",
409
+ "test_pickle = pickle.load(test_file)\n",
410
+ "print(test_pickle)\n",
411
+ "print(test_pickle[0])\n",
412
+ "\n",
413
+ "\n"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "markdown",
418
+ "metadata": {},
419
+ "source": [
420
+ "# Preview files data load and split"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": [
429
+ "\n",
430
+ "#testing_files = \n",
431
+ "preview_files_path = []\n",
432
+ "for testing_file in testing_files:\n",
433
+ " preview_files_path.append(Path(testing_file))\n",
434
+ "\n",
435
+ "preview_dir = Path(root_save / \"preview\")\n",
436
+ "split_files_for_training(\n",
437
+ " files_paths=preview_files_path,\n",
438
+ " tokenizer=tokenizer,\n",
439
+ " save_dir=preview_dir,\n",
440
+ " max_seq_len=sequence_length,\n",
441
+ " num_overlap_bars=2,\n",
442
+ " )\n",
443
+ "\n",
444
+ "valid_midi_path = root_save / \"Maestro_valid\"\n",
445
+ "midi_split_preview = list(valid_midi_path.resolve().glob(\"**/*.mid\")) + list(valid_midi_path.resolve().glob(\"**/*.midi\"))\n",
446
+ "\n",
447
+ "print(len(midi_split_preview))\n",
448
+ "file_name_lookup = []\n",
449
+ "def func_to_get_labels(p1, p2, p3):\n",
450
+ " if p3.name not in file_name_lookup:\n",
451
+ " file_name_lookup.append(p3.name)\n",
452
+ " return file_name_lookup.index(p3.name)\n",
453
+ " \n",
454
+ "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"func_to_get_labels\" : func_to_get_labels}\n",
455
+ "dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)"
456
+ ]
457
+ },
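A tiny illustration of the label lookup above: the function interns each file name and returns its stable index, so all chunks split from the same source file share one label.

```python
file_name_lookup = []

def get_label(name):
    # append on first sight, then return the stable index
    if name not in file_name_lookup:
        file_name_lookup.append(name)
    return file_name_lookup.index(name)

print(get_label("a.mid"), get_label("b.mid"), get_label("a.mid"))  # 0 1 0
```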
458
+ {
459
+ "attachments": {},
460
+ "cell_type": "markdown",
461
+ "metadata": {},
462
+ "source": [
463
+ "## Model initialization\n",
464
+ "\n",
465
+ "We will use the [Mistral implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/mistral).\n",
466
+ "Feel free to explore the documentation and source code to dig deeper.\n",
467
+ "\n",
468
+ "**You may need to adjust the model's configuration, the training configuration and the maximum input sequence length (cell above) depending on your hardware.**"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": [
477
+ "# Creates model\n",
478
+ "model_config = MistralConfig(\n",
479
+ " vocab_size=len(tokenizer), #from miditok output default 32K\n",
480
+ " hidden_size=512, # default 4096\n",
481
+ " intermediate_size=2048, # default 14336\n",
482
+ " num_hidden_layers=8, # default 32\n",
483
+ " num_attention_heads=8, # default 32\n",
484
+ " num_key_value_heads=4, # default 8\n",
485
+ " sliding_window=256, # default 4096\n",
486
+ " max_position_embeddings=8192, #has no effect on the parms count or training just limits the input length # default 4096*32\n",
487
+ " pad_token_id=tokenizer['PAD_None'],\n",
488
+ " bos_token_id=tokenizer['BOS_None'],\n",
489
+ " eos_token_id=tokenizer['EOS_None'],\n",
490
+ ")\n",
491
+ "model = AutoModelForCausalLM.from_config(model_config)"
492
+ ]
493
+ },
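A rough size check of this configuration, as a self-contained sketch (24000 stands in for `len(tokenizer)`):

```python
from transformers import AutoModelForCausalLM, MistralConfig

cfg = MistralConfig(vocab_size=24000, hidden_size=512, intermediate_size=2048,
                    num_hidden_layers=8, num_attention_heads=8,
                    num_key_value_heads=4, sliding_window=256)
m = AutoModelForCausalLM.from_config(cfg)
print(f"{sum(p.numel() for p in m.parameters()) / 1e6:.1f}M parameters")
```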
494
+ {
495
+ "attachments": {},
496
+ "cell_type": "markdown",
497
+ "metadata": {},
498
+ "source": [
499
+ "## Model training"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": null,
505
+ "metadata": {},
506
+ "outputs": [],
507
+ "source": [
508
+ "model_dir = root_save / 'run'\n",
509
+ "model_dir_str = str(model_dir)\n",
510
+ "print(model_dir)"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "metadata": {},
517
+ "outputs": [],
518
+ "source": [
519
+ "metrics = {metric: load_metric(metric) for metric in [\"accuracy\"]}\n",
520
+ "\n",
521
+ "def compute_metrics(eval_pred):\n",
522
+ " \"\"\"\n",
523
+ " Compute metrics for pretraining.\n",
524
+ "\n",
525
+ " Must use preprocess_logits function that converts logits to predictions (argmax or sampling).\n",
526
+ "\n",
527
+ " :param eval_pred: EvalPrediction containing predictions and labels\n",
528
+ " :return: metrics\n",
529
+ " \"\"\"\n",
530
+ " predictions, labels = eval_pred\n",
531
+ " not_pad_mask = labels != -100\n",
532
+ " labels, predictions = labels[not_pad_mask], predictions[not_pad_mask]\n",
533
+ " return metrics[\"accuracy\"].compute(predictions=predictions.flatten(), references=labels.flatten())\n",
534
+ "\n",
535
+ "def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:\n",
536
+ " \"\"\"\n",
537
+ " Preprocess the logits before accumulating them during evaluation.\n",
538
+ "\n",
539
+ " This allows to significantly reduce the memory usage and make the training tractable.\n",
540
+ " \"\"\"\n",
541
+ " pred_ids = argmax(logits, dim=-1) # long dtype\n",
542
+ " return pred_ids\n",
543
+ "\n",
544
+ "# Create config for the Trainer\n",
545
+ "USE_CUDA = cuda_available()\n",
546
+ "print(USE_CUDA)\n",
547
+ "if not cuda_available():\n",
548
+ " FP16 = FP16_EVAL = BF16 = BF16_EVAL = False\n",
549
+ "elif is_bf16_supported():\n",
550
+ " BF16 = BF16_EVAL = True\n",
551
+ " FP16 = FP16_EVAL = False\n",
552
+ "else:\n",
553
+ " BF16 = BF16_EVAL = False\n",
554
+ " FP16 = FP16_EVAL = True\n",
555
+ "USE_MPS = not USE_CUDA and mps_available()\n",
556
+ "training_config = TrainingArguments(\n",
557
+ " model_dir_str, False, True, True, False, \"steps\",\n",
558
+ " per_device_train_batch_size=24, #76% @ 24 batch size #76% @ 32 batch size try 64 batch size next time \n",
559
+ " per_device_eval_batch_size=24, #was 24 now 32\n",
560
+ " gradient_accumulation_steps=3, #change this to 4\n",
561
+ " eval_accumulation_steps=None,\n",
562
+ " eval_steps=1000,\n",
563
+ " learning_rate=1e-4,\n",
564
+ " weight_decay=0.01,\n",
565
+ " max_grad_norm=3.0,\n",
566
+ " max_steps=40000,\n",
567
+ " lr_scheduler_type=\"cosine_with_restarts\",\n",
568
+ " warmup_ratio=0.3,\n",
569
+ " log_level=\"debug\",\n",
570
+ " logging_strategy=\"steps\",\n",
571
+ " logging_steps=20,\n",
572
+ " save_strategy=\"steps\",\n",
573
+ " save_steps=1000,\n",
574
+ " save_total_limit=5,\n",
575
+ " no_cuda=not USE_CUDA,\n",
576
+ " seed=444,\n",
577
+ " fp16=FP16,\n",
578
+ " fp16_full_eval=FP16_EVAL,\n",
579
+ " bf16=BF16,\n",
580
+ " bf16_full_eval=BF16_EVAL,\n",
581
+ " load_best_model_at_end=True,\n",
582
+ " label_smoothing_factor=0.,\n",
583
+ " optim=\"adamw_torch\",\n",
584
+ " report_to=[\"tensorboard\"],\n",
585
+ " gradient_checkpointing=True,\n",
586
+ " dataloader_num_workers=8, #added to fix trashing isssue with the gpu not having enough data to process\n",
587
+ " dataloader_pin_memory=True, #we want the dataset in memory\n",
588
+ " torch_compile=True #added to speed up \n",
589
+ " \n",
590
+ ")\n",
591
+ "\n",
592
+ "collator = DataCollator(tokenizer[\"PAD_None\"], copy_inputs_as_labels=True, pad_on_left=True) #not sure about the pad_on_left, it might get better results\n",
593
+ "trainer = Trainer(\n",
594
+ " model=model,\n",
595
+ " args=training_config,\n",
596
+ " data_collator=collator,\n",
597
+ " train_dataset=dataset_train,\n",
598
+ " eval_dataset=dataset_valid,\n",
599
+ " compute_metrics=compute_metrics,\n",
600
+ " callbacks=None,\n",
601
+ " preprocess_logits_for_metrics=preprocess_logits,\n",
602
+ " \n",
603
+ ")\n",
604
+ "\n"
605
+ ]
606
+ },
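For reference, the effective batch size per optimizer step implied by these arguments, assuming a single GPU:

```python
per_device_train_batch_size = 24
gradient_accumulation_steps = 3
# sequences consumed per optimizer step = 24 * 3 = 72
print(per_device_train_batch_size * gradient_accumulation_steps)
```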
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": null,
610
+ "metadata": {},
611
+ "outputs": [],
612
+ "source": [
613
+ "del model\n",
614
+ "torch.cuda.empty_cache()"
615
+ ]
616
+ },
617
+ {
618
+ "cell_type": "code",
619
+ "execution_count": null,
620
+ "metadata": {},
621
+ "outputs": [],
622
+ "source": [
623
+ "print(model)\n"
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "execution_count": null,
629
+ "metadata": {},
630
+ "outputs": [],
631
+ "source": [
632
+ "%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
633
+ ]
634
+ },
635
+ {
636
+ "cell_type": "code",
637
+ "execution_count": null,
638
+ "metadata": {},
639
+ "outputs": [],
640
+ "source": [
641
+ "# Training\n",
642
+ "train_result = trainer.train()\n",
643
+ "trainer.save_model() # Saves the tokenizer too\n",
644
+ "trainer.log_metrics(\"train\", train_result.metrics)\n",
645
+ "trainer.save_metrics(\"train\", train_result.metrics)\n",
646
+ "trainer.save_state()"
647
+ ]
648
+ },
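If the run is interrupted, the standard `transformers` API can pick up from the latest checkpoint saved under the output directory (`save_steps=1000` above):

```python
# Resume from the most recent checkpoint in model_dir
train_result = trainer.train(resume_from_checkpoint=True)
```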
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": null,
652
+ "metadata": {},
653
+ "outputs": [],
654
+ "source": [
655
+ "model.create_model_card(tags=[\"mistral\", \"midi\", \"miditok\", \"music\", \"instrument\"],\n",
656
+ " model_name=\"Mistral_MidiTok_Transformer_Single_Instrument_Small\")"
657
+ ]
658
+ },
659
+ {
660
+ "cell_type": "code",
661
+ "execution_count": null,
662
+ "metadata": {},
663
+ "outputs": [],
664
+ "source": [
665
+ "\n",
666
+ "model.hub_model_id = \"adricl/midi_single_instrument_mistral_transformer\"\n",
667
+ "\n",
668
+ "model.push_to_hub(commit_message=\"Training Basic Model for Mistral MidiTok Transformer Single Instrument Small\", repo_id=\"adricl/midi_single_instrument_mistral_transformer\",\n",
669
+ " token=\"\")\n"
670
+ ]
671
+ },
672
+ {
673
+ "cell_type": "markdown",
674
+ "metadata": {},
675
+ "source": [
676
+ "# For Tensorboard tensorboard --logdir runs/"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "code",
681
+ "execution_count": null,
682
+ "metadata": {},
683
+ "outputs": [],
684
+ "source": [
685
+ "config = AutoConfig.from_pretrained(str(model_dir / \"config.json\"))\n",
686
+ "model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=str(model_dir / \"model.safetensors\"), from_tf=False, config=config)"
687
+ ]
688
+ },
689
+ {
690
+ "attachments": {},
691
+ "cell_type": "markdown",
692
+ "metadata": {},
693
+ "source": [
694
+ "## Generate music"
695
+ ]
696
+ },
697
+ {
698
+ "cell_type": "code",
699
+ "execution_count": null,
700
+ "metadata": {
701
+ "cellView": "form",
702
+ "id": "OaNkGcFo9UP_"
703
+ },
704
+ "outputs": [],
705
+ "source": [
706
+ "# for single track midi files splits \n",
707
+ "\n",
708
+ "gen_results_path = root_save / 'gen_res'\n",
709
+ "gen_results_path.mkdir(parents=True, exist_ok=True)\n",
710
+ "generation_config = GenerationConfig(\n",
711
+ " max_new_tokens=200, # extends samples by 200 tokens\n",
712
+ " num_beams=1, # no beam search\n",
713
+ " do_sample=True, # but sample instead\n",
714
+ " temperature=0.9,\n",
715
+ " top_k=15,\n",
716
+ " top_p=0.95,\n",
717
+ " epsilon_cutoff=3e-4,\n",
718
+ " eta_cutoff=1e-3,\n",
719
+ " pad_token_id=tokenizer.pad_token_id,\n",
720
+ ")\n",
721
+ "\n",
722
+ "# Here the sequences are padded to the left, so that the last token along the time dimension\n",
723
+ "# is always the last token of each seq, allowing to efficiently generate by batch\n",
724
+ "collator.pad_on_left = True\n",
725
+ "collator.eos_token = None\n",
726
+ "dataloader_test = DataLoader(dataset_preview, batch_size=24, collate_fn=collator)\n",
727
+ "model.eval()\n",
728
+ "count = 0\n",
729
+ "for batch in tqdm(dataloader_test, desc='Testing model / Generating results'): # (N,T)\n",
730
+ " print(batch)\n",
731
+ " res = model.generate(\n",
732
+ " inputs=batch[\"input_ids\"].to(model.device),\n",
733
+ " attention_mask=batch[\"attention_mask\"].to(model.device),\n",
734
+ " generation_config=generation_config) # (N,T)\n",
735
+ "\n",
736
+ "\n",
737
+ " # Saves the generated music, as MIDI files and tokens (json)\n",
738
+ " for prompt, continuation in zip(batch[\"input_ids\"], res):\n",
739
+ " generated = continuation[len(prompt):]\n",
740
+ " midi = tokenizer.decode([deepcopy(generated.tolist())])\n",
741
+ " tokens = [generated, prompt, continuation] # list compr. as seqs of dif. lengths\n",
742
+ " tokens = [seq.tolist() for seq in tokens]\n",
743
+ " for tok_seq in tokens[1:]:\n",
744
+ " _midi = tokenizer.decode([deepcopy(tok_seq)])\n",
745
+ " midi.tracks.append(_midi.tracks[0])\n",
746
+ " \n",
747
+ " file_name = file_name_lookup[count]\n",
748
+ " print(file_name)\n",
749
+ " midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens) Original file {file_name}'\n",
750
+ " midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'\n",
751
+ " if (len(midi.tracks) > 2):\n",
752
+ " midi.tracks[2].name = f'Original sample and continuation'\n",
753
+ " midi.dump_midi(gen_results_path / f'{count}_{file_name}.mid')\n",
754
+ " tokenizer.save_tokens(tokens, gen_results_path / f'{count}_{file_name}.json') \n",
755
+ "\n",
756
+ " count += 1"
757
+ ]
758
+ },
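A hedged sketch of reloading one of the token files saved above and decoding it back to MIDI; the file name is hypothetical, and miditok's `save_tokens`/`load_tokens` store the sequences under an `"ids"` key in the JSON.

```python
# Reload a saved token JSON and decode it back to a MIDI file
tokens = tokenizer.load_tokens(gen_results_path / "0_example.json")
midi = tokenizer.decode(tokens["ids"])
midi.dump_midi(gen_results_path / "decoded_check.mid")
```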
759
+ {
760
+ "cell_type": "code",
761
+ "execution_count": null,
762
+ "metadata": {},
763
+ "outputs": [],
764
+ "source": [
765
+ "print(file_name_lookup)"
766
+ ]
767
+ }
768
+ ],
769
+ "metadata": {
770
+ "accelerator": "GPU",
771
+ "colab": {
772
+ "collapsed_sections": [],
773
+ "machine_shape": "hm",
774
+ "name": "Optimus_VIRTUOSO_Multi_Instrumental_RGA_Edition.ipynb",
775
+ "private_outputs": true,
776
+ "provenance": []
777
+ },
778
+ "kernelspec": {
779
+ "display_name": "Python 3 (ipykernel)",
780
+ "language": "python",
781
+ "name": "python3"
782
+ },
783
+ "language_info": {
784
+ "codemirror_mode": {
785
+ "name": "ipython",
786
+ "version": 3
787
+ },
788
+ "file_extension": ".py",
789
+ "mimetype": "text/x-python",
790
+ "name": "python",
791
+ "nbconvert_exporter": "python",
792
+ "pygments_lexer": "ipython3",
793
+ "version": "3.9.5"
794
+ },
795
+ "vscode": {
796
+ "interpreter": {
797
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
798
+ }
799
+ }
800
+ },
801
+ "nbformat": 4,
802
+ "nbformat_minor": 4
803
+ }
HuggingFace_Mistral_Transformer_Single_Instrument.ipynb CHANGED
@@ -37,32 +37,181 @@
37
  },
38
  {
39
  "cell_type": "code",
40
- "execution_count": null,
41
  "metadata": {
42
  "cellView": "form",
43
  "id": "fX12Yquyuihc"
44
  },
45
  "outputs": [],
46
  "source": [
47
- "\n",
48
- "\n",
49
  "from copy import deepcopy\n",
50
  "from pathlib import Path\n",
51
- "from random import shuffle\n",
52
  "\n",
53
  "from evaluate import load as load_metric\n",
54
- "from miditok import REMI, TokenizerConfig, TokTrainingIterator\n",
55
  "from miditok.pytorch_data import DatasetMIDI, DataCollator\n",
56
  "from miditok.utils import split_files_for_training\n",
57
  "\n",
58
  "from miditok.data_augmentation import augment_dataset\n",
59
- "from torch import Tensor, argmax\n",
60
  "from torch.utils.data import DataLoader\n",
61
  "from torch.cuda import is_available as cuda_available, is_bf16_supported\n",
62
  "from torch.backends.mps import is_available as mps_available\n",
63
  "from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig\n",
64
  "from transformers.trainer_utils import set_seed\n",
65
- "from tqdm import tqdm"
66
  ]
67
  },
68
  {
@@ -75,13 +224,10 @@
75
  },
76
  {
77
  "cell_type": "code",
78
- "execution_count": null,
79
  "metadata": {},
80
  "outputs": [],
81
  "source": [
82
- "# Seed\n",
83
- "set_seed(777)\n",
84
- "\n",
85
  "# Our tokenizer's configuration\n",
86
  "BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}\n",
87
  "TOKENIZER_PARAMS = {\n",
@@ -114,14 +260,16 @@
114
  },
115
  {
116
  "cell_type": "code",
117
- "execution_count": null,
118
  "metadata": {},
119
  "outputs": [],
120
  "source": [
121
- "root_data_dir = Path('/home/wombat/Documents/projects/music/midiTok/data/')\n",
122
  "root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')\n",
123
  "\n",
124
- "tokenizer_name = \"HuggingFace_Mistral_Transformer_Single_Instrument.json\""
125
  ]
126
  },
127
  {
@@ -132,13 +280,23 @@
132
  "source": [
133
  "\n",
134
  "# Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 30k tokens\n",
135
- "data_dirs = [\"adl-piano-midi\", \"maestro-v3.0.0\", \"musicnet_midis\" ] # for single \n",
136
  "midi_paths = []\n",
137
  "for data_dir in data_dirs:\n",
138
  " path = Path(root_data_dir / 'Traning Data' / data_dir)\n",
139
  " midi_paths.extend(list(path.resolve().glob(\"**/*.mid\")) + list(path.resolve().glob(\"**/*.midi\")))\n",
140
  "\n",
141
- "print(f\"Found {len(midi_paths)} MIDI files\")"
142
  ]
143
  },
144
  {
@@ -149,8 +307,8 @@
149
  "source": [
150
  "#Note the size of the dataset is quite large, so it requires a huge amount of memory to train the tokenizer for 61749 files it took 64gb of memory\n",
151
  "tokenizer.train(\n",
152
- " vocab_size=32000,\n",
153
- " files_paths=midi_paths,\n",
154
  ")\n",
155
  "tokenizer.save(root_save / tokenizer_name)\n",
156
  "\n"
@@ -158,11 +316,11 @@
158
  },
159
  {
160
  "cell_type": "code",
161
- "execution_count": null,
162
  "metadata": {},
163
  "outputs": [],
164
  "source": [
165
- "tokenizer = REMI(params=Path(root_save / tokenizer_name))"
166
  ]
167
  },
168
  {
@@ -177,12 +335,19 @@
177
  },
178
  {
179
  "cell_type": "code",
180
- "execution_count": null,
181
  "metadata": {},
182
  "outputs": [],
183
  "source": [
184
  "sequence_length = 1024 # The maximum sequence length for data samples.\n",
185
- "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"]}"
186
  ]
187
  },
188
  {
@@ -191,30 +356,27 @@
191
  "metadata": {},
192
  "outputs": [],
193
  "source": [
194
- "def remove_files_with_boring_data( file_paths: list[Path], rms_threshold: float = 0.01) -> list[Path]:\n",
195
- " \"\"\"\n",
196
- " Remove files with boring data, i.e. files with low RMS.\n",
197
- " \"\"\"\n",
198
- " from symusic import Score\n",
199
- " from tqdm import tqdm\n",
200
- " import numpy as np\n",
201
- "\n",
202
- " rms = lambda data: (sum(x * x for x in data) / len(data)) ** 0.5\n",
203
- "\n",
204
- " filtered_files = []\n",
205
- " for file_path in tqdm(file_paths, desc=\"Filtering boring files\"):\n",
206
- " try:\n",
207
- " scores = [Score(file_path)]\n",
208
- " except SCORE_LOADING_EXCEPTION:\n",
209
- " continue\n",
210
- "\n",
211
- " for track in scores[0].tracks:\n",
212
- " values = track.notes['pitch']\n",
213
- " result = rms(values)\n",
214
- "\n",
215
  "\n",
216
- " filtered_files.append(file_path)\n",
217
- " return filtered_files"
218
  ]
219
  },
220
  {
@@ -225,6 +387,7 @@
225
  "source": [
226
  "# Split MIDI paths in train/valid/test sets\n",
227
  "total_num_files = len(midi_paths)\n",
228
  "num_files_valid = round(total_num_files * 0.15)\n",
229
  "num_files_test = round(total_num_files * 0.15)\n",
230
  "shuffle(midi_paths)\n",
@@ -248,6 +411,7 @@
248
  " save_dir=subset_chunks_dir,\n",
249
  " max_seq_len=sequence_length,\n",
250
  " num_overlap_bars=2,\n",
251
  " )\n",
252
  "\n",
253
  " if subset_name == 'train':\n",
@@ -261,6 +425,16 @@
261
  " )\n"
262
  ]
263
  },
264
  {
265
  "cell_type": "code",
266
  "execution_count": null,
@@ -269,47 +443,39 @@
269
  "source": [
270
  "# Create Dataset and Collator for training\n",
271
  "midi_paths_train = list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.midi\"))\n",
272
- "midi_paths_valid = list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.midi\")) \n",
273
- "midi_paths_test = list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.midi\"))\n",
274
  "\n",
275
- "\n",
276
- "\n",
277
- "dataset_train = DatasetMIDI(midi_paths_train, **kwargs_dataset)\n",
278
- "dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)\n",
279
- "dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)\n",
280
- "print (len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))"
281
  ]
282
  },
283
  {
284
- "cell_type": "markdown",
285
  "metadata": {},
286
  "source": [
287
- "# Preview files data load and split"
288
  ]
289
  },
290
  {
291
  "cell_type": "code",
292
  "execution_count": null,
293
- "metadata": {
294
- "tags": [
295
- "Generate Preview Files"
296
- ]
297
- },
298
  "outputs": [],
299
  "source": [
300
- "#testing_files = \n",
301
- "preview_files_path = []\n",
302
- "for testing_file in testing_files:\n",
303
- " preview_files_path.append(Path(testing_file))\n",
304
- "\n",
305
- "preview_dir = Path(root_save / \"preview\")\n",
306
- "split_files_for_training(\n",
307
- " files_paths=preview_files_path,\n",
308
- " tokenizer=tokenizer,\n",
309
- " save_dir=preview_dir,\n",
310
- " max_seq_len=sequence_length,\n",
311
- " num_overlap_bars=2,\n",
312
- " )\n"
313
  ]
314
  },
315
  {
@@ -318,18 +484,7 @@
318
  "metadata": {},
319
  "outputs": [],
320
  "source": [
321
- "valid_midi_path = root_save / \"Maestro_valid\"\n",
322
- "midi_split_preview = list(valid_midi_path.resolve().glob(\"**/*.mid\")) + list(valid_midi_path.resolve().glob(\"**/*.midi\"))\n",
323
- "\n",
324
- "print(len(midi_split_preview))\n",
325
- "file_name_lookup = []\n",
326
- "def func_to_get_labels(p1, p2, p3):\n",
327
- " if p3.name not in file_name_lookup:\n",
328
- " file_name_lookup.append(p3.name)\n",
329
- " return file_name_lookup.index(p3.name)\n",
330
- " \n",
331
- "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"func_to_get_labels\" : func_to_get_labels}\n",
332
- "dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)"
333
  ]
334
  },
335
  {
@@ -341,12 +496,15 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": null,
345
  "metadata": {},
346
  "outputs": [],
347
  "source": [
348
- "dataset_dir = root_save / \"data\"\n",
349
- "dataset_dir.mkdir(parents=True, exist_ok=True)"
350
  ]
351
  },
352
  {
@@ -355,23 +513,27 @@
355
  "metadata": {},
356
  "outputs": [],
357
  "source": [
358
- "import torch\n",
359
- "torch.save(dataset_train, Path(dataset_dir / \"dataset_train.pt\"))\n",
360
- "torch.save(dataset_valid, Path(dataset_dir / \"dataset_valid.pt\"))\n",
361
- "torch.save(dataset_test, Path(dataset_dir / \"dataset_test.pt\"))\n"
362
  ]
363
  },
364
  {
365
- "cell_type": "code",
366
- "execution_count": null,
367
  "metadata": {},
368
- "outputs": [],
369
  "source": [
370
- "import torch\n",
371
- "dataset_train = torch.load(Path(dataset_dir / \"dataset_train.pt\"))\n",
372
- "dataset_valid = torch.load(Path(dataset_dir / \"dataset_valid.pt\"))\n",
373
- "dataset_test = torch.load(Path(dataset_dir / \"dataset_test.pt\"))\n",
374
- "\n"
375
  ]
376
  },
377
  {
@@ -380,7 +542,33 @@
380
  "metadata": {},
381
  "outputs": [],
382
  "source": [
383
- "print(dataset_train[0])\n"
384
  ]
385
  },
386
  {
@@ -398,9 +586,22 @@
398
  },
399
  {
400
  "cell_type": "code",
401
- "execution_count": null,
402
  "metadata": {},
403
- "outputs": [],
404
  "source": [
405
  "# Creates model\n",
406
  "model_config = MistralConfig(\n",
@@ -411,12 +612,80 @@
411
  " num_attention_heads=8, # default 32\n",
412
  " num_key_value_heads=4, # default 8\n",
413
  " sliding_window=256, # default 4096\n",
414
- " max_position_embeddings=sequence_length + 256, # 8192 this was before # default 4096*32\n",
415
  " pad_token_id=tokenizer['PAD_None'],\n",
416
  " bos_token_id=tokenizer['BOS_None'],\n",
417
  " eos_token_id=tokenizer['EOS_None'],\n",
418
  ")\n",
419
- "model = AutoModelForCausalLM.from_config(model_config)"
420
  ]
421
  },
422
  {
@@ -429,9 +698,17 @@
429
  },
430
  {
431
  "cell_type": "code",
432
- "execution_count": null,
433
  "metadata": {},
434
- "outputs": [],
435
  "source": [
436
  "model_dir = root_save / 'run'\n",
437
  "model_dir_str = str(model_dir)\n",
@@ -442,7 +719,25 @@
442
  "cell_type": "code",
443
  "execution_count": null,
444
  "metadata": {},
445
- "outputs": [],
446
  "source": [
447
  "metrics = {metric: load_metric(metric) for metric in [\"accuracy\"]}\n",
448
  "\n",
@@ -483,22 +778,23 @@
483
  "USE_MPS = not USE_CUDA and mps_available()\n",
484
  "training_config = TrainingArguments(\n",
485
  " model_dir_str, False, True, True, False, \"steps\",\n",
486
- " per_device_train_batch_size=30, #76% @ 24 batch size #76% @ 32 batch size try 64 batch size next time \n",
487
- " per_device_eval_batch_size=30, #was 24 now 32\n",
488
  " gradient_accumulation_steps=3, #change this to 4\n",
489
  " eval_accumulation_steps=None,\n",
490
- " eval_steps=1000,\n",
491
  " learning_rate=1e-4,\n",
492
  " weight_decay=0.01,\n",
493
- " max_grad_norm=3.0,\n",
494
- " max_steps=40000,\n",
495
- " lr_scheduler_type=\"cosine_with_restarts\",\n",
496
- " warmup_ratio=0.3,\n",
497
  " log_level=\"debug\",\n",
498
  " logging_strategy=\"steps\",\n",
499
- " logging_steps=20,\n",
500
  " save_strategy=\"steps\",\n",
501
- " save_steps=1000,\n",
502
  " save_total_limit=5,\n",
503
  " no_cuda=not USE_CUDA,\n",
504
  " seed=444,\n",
@@ -509,11 +805,11 @@
509
  " load_best_model_at_end=True,\n",
510
  " label_smoothing_factor=0.,\n",
511
  " optim=\"adamw_torch\",\n",
512
- " report_to=[\"tensorboard\"],\n",
513
- " gradient_checkpointing=True,\n",
514
  " dataloader_num_workers=8, #added to fix trashing isssue with the gpu not having enough data to process\n",
515
  " dataloader_pin_memory=True, #we want the dataset in memory\n",
516
- " torch_compile=True #added to speed up \n",
517
  " \n",
518
  ")\n",
519
  "\n",
@@ -538,7 +834,18 @@
538
  "metadata": {},
539
  "outputs": [],
540
  "source": [
541
- "print(model)"
542
  ]
543
  },
544
  {
@@ -546,6 +853,159 @@
546
  "execution_count": null,
547
  "metadata": {},
548
  "outputs": [],
549
  "source": [
550
  "# Training\n",
551
  "train_result = trainer.train()\n",
@@ -555,6 +1015,62 @@
555
  "trainer.save_state()"
556
  ]
557
  },
558
  {
559
  "cell_type": "code",
560
  "execution_count": null,
@@ -567,9 +1083,32 @@
567
  },
568
  {
569
  "cell_type": "code",
570
- "execution_count": null,
571
  "metadata": {},
572
- "outputs": [],
573
  "source": [
574
  "\n",
575
  "model.hub_model_id = \"adricl/midi_single_instrument_mistral_transformer\"\n",
@@ -699,7 +1238,7 @@
699
  "name": "python",
700
  "nbconvert_exporter": "python",
701
  "pygments_lexer": "ipython3",
702
- "version": "3.9.5"
703
  },
704
  "vscode": {
705
  "interpreter": {
37
  },
38
  {
39
  "cell_type": "code",
40
+ "execution_count": 1,
41
+ "metadata": {},
42
+ "outputs": [
43
+ {
44
+ "name": "stdout",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Requirement already satisfied: pip in /usr/local/lib/python3.11/dist-packages (25.2)\n"
48
+ ]
49
+ },
50
+ {
51
+ "name": "stderr",
52
+ "output_type": "stream",
53
+ "text": [
54
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
55
+ "\u001b[0m"
56
+ ]
57
+ }
58
+ ],
59
+ "source": [
60
+ "%%python -m pip install --upgrade pip\n",
61
+ "\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 2,
67
+ "metadata": {},
68
+ "outputs": [
69
+ {
70
+ "name": "stdout",
71
+ "output_type": "stream",
72
+ "text": [
73
+ "Requirement already satisfied: evaluate in /usr/local/lib/python3.11/dist-packages (0.4.6)\n",
74
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.56.2)\n",
75
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
76
+ "Requirement already satisfied: miditok in /usr/local/lib/python3.11/dist-packages (3.0.6.post1)\n",
77
+ "Requirement already satisfied: accelerate in /usr/local/lib/python3.11/dist-packages (1.10.1)\n",
78
+ "Requirement already satisfied: tensorboardX in /usr/local/lib/python3.11/dist-packages (2.6.4)\n",
79
+ "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.7.2)\n",
80
+ "Requirement already satisfied: wandb in /usr/local/lib/python3.11/dist-packages (0.22.0)\n",
81
+ "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (4.1.1)\n",
82
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.1.2)\n",
83
+ "Requirement already satisfied: dill in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.4.0)\n",
84
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.3.2)\n",
85
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.32.3)\n",
86
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from evaluate) (3.5.0)\n",
87
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.70.16)\n",
88
+ "Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2024.10.0)\n",
89
+ "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.35.1)\n",
90
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from evaluate) (24.2)\n",
91
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.16.1)\n",
92
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
93
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2025.9.18)\n",
94
+ "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.22.1)\n",
95
+ "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.6.2)\n",
96
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.12.2)\n",
97
+ "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.1.10)\n",
98
+ "Requirement already satisfied: symusic>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from miditok) (0.5.8)\n",
99
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate) (7.0.0)\n",
100
+ "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate) (2.8.0.dev20250319+cu128)\n",
101
+ "Requirement already satisfied: protobuf>=3.20 in /usr/local/lib/python3.11/dist-packages (from tensorboardX) (6.32.1)\n",
102
+ "Requirement already satisfied: scipy>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.16.2)\n",
103
+ "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.5.2)\n",
104
+ "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.6.0)\n",
105
+ "Requirement already satisfied: click>=8.0.1 in /usr/local/lib/python3.11/dist-packages (from wandb) (8.3.0)\n",
106
+ "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from wandb) (3.1.45)\n",
107
+ "Requirement already satisfied: platformdirs in /usr/local/lib/python3.11/dist-packages (from wandb) (4.3.7)\n",
108
+ "Requirement already satisfied: pydantic<3 in /usr/local/lib/python3.11/dist-packages (from wandb) (2.11.9)\n",
109
+ "Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from wandb) (2.39.0)\n",
110
+ "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb) (0.7.0)\n",
111
+ "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb) (2.33.2)\n",
112
+ "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb) (0.4.1)\n",
113
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.4.1)\n",
114
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.10)\n",
115
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2.3.0)\n",
116
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2025.1.31)\n",
117
+ "Requirement already satisfied: pyarrow>=21.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (21.0.0)\n",
118
+ "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (3.12.15)\n",
119
+ "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (2.6.1)\n",
120
+ "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.4.0)\n",
121
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (25.3.0)\n",
122
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.7.0)\n",
123
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (6.6.4)\n",
124
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (0.3.2)\n",
125
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.20.1)\n",
126
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.12)\n",
127
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.2)\n",
128
+ "Requirement already satisfied: pySmartDL in /usr/local/lib/python3.11/dist-packages (from symusic>=0.5.0->miditok) (1.3.4)\n",
129
+ "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (1.13.3)\n",
130
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (3.4.2)\n",
131
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (3.1.4)\n",
132
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.61 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.61)\n",
133
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.57 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.57)\n",
134
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.57 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.57)\n",
135
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.8.0.87 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (9.8.0.87)\n",
136
+ "Requirement already satisfied: nvidia-cublas-cu12==12.8.3.14 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.3.14)\n",
137
+ "Requirement already satisfied: nvidia-cufft-cu12==11.3.3.41 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (11.3.3.41)\n",
138
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.9.55 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (10.3.9.55)\n",
139
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.7.2.55 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (11.7.2.55)\n",
140
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.5.7.53 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.5.7.53)\n",
141
+ "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (0.6.3)\n",
142
+ "Requirement already satisfied: nvidia-nccl-cu12==2.25.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (2.25.1)\n",
143
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.8.55 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.55)\n",
144
+ "Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.61 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.61)\n",
145
+ "Requirement already satisfied: nvidia-cufile-cu12==1.13.0.11 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (1.13.0.11)\n",
146
+ "Requirement already satisfied: pytorch-triton==3.3.0+git96316ce5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (3.3.0+git96316ce5)\n",
147
+ "Requirement already satisfied: setuptools>=40.8.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-triton==3.3.0+git96316ce5->torch>=2.0.0->accelerate) (77.0.1)\n",
148
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0)\n",
149
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=2.0.0->accelerate) (2.1.5)\n",
150
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2.9.0.post0)\n",
151
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2025.2)\n",
152
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2025.2)\n",
153
+ "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.16.0)\n",
154
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
155
+ "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
156
+ ]
157
+ }
158
+ ],
159
+ "source": [
160
+ "%pip install evaluate transformers tqdm miditok accelerate tensorboardX scikit-learn wandb"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": 3,
166
  "metadata": {
167
  "cellView": "form",
168
  "id": "fX12Yquyuihc"
169
  },
170
  "outputs": [],
171
  "source": [
 
 
172
  "from copy import deepcopy\n",
173
  "from pathlib import Path\n",
174
+ "from random import shuffle, sample\n",
175
  "\n",
176
  "from evaluate import load as load_metric\n",
177
+ "from miditok import REMI, TokenizerConfig\n",
178
  "from miditok.pytorch_data import DatasetMIDI, DataCollator\n",
179
  "from miditok.utils import split_files_for_training\n",
180
  "\n",
181
  "from miditok.data_augmentation import augment_dataset\n",
182
+ "from torch import Tensor, argmax, torch\n",
183
  "from torch.utils.data import DataLoader\n",
184
  "from torch.cuda import is_available as cuda_available, is_bf16_supported\n",
185
  "from torch.backends.mps import is_available as mps_available\n",
186
  "from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig\n",
187
  "from transformers.trainer_utils import set_seed\n",
188
+ "from tqdm import tqdm\n",
189
+ "\n",
190
+ "#Seed\n",
191
+ "set_seed(777)"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 4,
197
+ "metadata": {},
198
+ "outputs": [
199
+ {
200
+ "name": "stderr",
201
+ "output_type": "stream",
202
+ "text": [
203
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33madric-landman\u001b[0m (\u001b[33madric-landman-hobby\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
204
+ ]
205
+ }
206
+ ],
207
+ "source": [
208
+ "# Wand db\n",
209
+ "import wandb\n",
210
+ "import os\n",
211
+ "wandb.login() #6bfb401d368c508cb01a291a8ae84c0ecce2310d \n",
212
+ "\n",
213
+ "os.environ[\"WANDB_PROJECT\"]=\"midi_music_maker\"\n",
214
+ "\n"
215
  ]
216
  },
217
  {
 
224
  },
225
  {
226
  "cell_type": "code",
227
+ "execution_count": 6,
228
  "metadata": {},
229
  "outputs": [],
230
  "source": [
 
 
 
231
  "# Our tokenizer's configuration\n",
232
  "BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}\n",
233
  "TOKENIZER_PARAMS = {\n",
 
260
  },
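+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal sketch of what BEAT_RES encodes: each (start_beat, end_beat) range maps\n",
+ "# to a resolution in positions per beat, so onsets near the current beat are\n",
+ "# quantized finely and distant ones coarsely.\n",
+ "for (start, end), res in BEAT_RES.items():\n",
+ "    print(f\"beats {start}-{end}: {res} positions/beat -> {(end - start) * res} positions\")\n"
+ ]
+ },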
261
  {
262
  "cell_type": "code",
263
+ "execution_count": 7,
264
  "metadata": {},
265
  "outputs": [],
266
  "source": [
267
+ "root_data_dir = Path('/workspace/traindata/data')\n",
268
  "root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')\n",
269
  "\n",
270
+ "tokenizer_name = \"HuggingFace_Mistral_Transformer_Single_Instrument_v4_single_track.json\"\n",
271
+ "dataset_dir = root_save / \"data\"\n",
272
+ "dataset_dir.mkdir(parents=True, exist_ok=True)"
273
  ]
274
  },
275
  {
 
280
  "source": [
281
  "\n",
282
  "# Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 30k tokens\n",
283
+ "#data_dirs = [\"adl-piano-midi\", \"maestro-v3.0.0\", \"musicnet_midis\" ] # for single \n",
284
+ "data_dirs = [\"MIDIs\"]\n",
285
  "midi_paths = []\n",
286
  "for data_dir in data_dirs:\n",
287
  " path = Path(root_data_dir / 'Traning Data' / data_dir)\n",
288
  " midi_paths.extend(list(path.resolve().glob(\"**/*.mid\")) + list(path.resolve().glob(\"**/*.midi\")))\n",
289
  "\n",
290
+ "print(f\"Found {len(midi_paths)} MIDI files\")\n",
291
+ "\n",
292
+ "shuffle(midi_paths)\n",
293
+ "\n",
294
+ "# We need a subset of files otherwise training tokenizer takes too long\n",
295
+ "percentage_to_select = 0.15\n",
296
+ "num_files_to_select = int(len(midi_paths) * percentage_to_select)\n",
297
+ "\n",
298
+ "subset_midi_paths = sample(midi_paths, num_files_to_select)\n",
299
+ "print(f\"Found {len(subset_midi_paths)} MIDI files\")"
300
  ]
301
  },
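+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch: drop files symusic cannot parse before training the tokenizer,\n",
+ "# assuming every path in subset_midi_paths should load cleanly as a Score.\n",
+ "from symusic import Score\n",
+ "\n",
+ "readable_paths = []\n",
+ "for p in subset_midi_paths:\n",
+ "    try:\n",
+ "        Score(str(p))  # parse only; the result is discarded\n",
+ "        readable_paths.append(p)\n",
+ "    except Exception:\n",
+ "        print(f\"Skipping unreadable file: {p}\")\n",
+ "print(f\"{len(readable_paths)} of {len(subset_midi_paths)} files are readable\")\n"
+ ]
+ },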
302
  {
 
307
  "source": [
308
  "#Note the size of the dataset is quite large, so it requires a huge amount of memory to train the tokenizer for 61749 files it took 64gb of memory\n",
309
  "tokenizer.train(\n",
310
+ " vocab_size=24000,\n",
311
+ " files_paths=subset_midi_paths,\n",
312
  ")\n",
313
  "tokenizer.save(root_save / tokenizer_name)\n",
314
  "\n"
 
316
  },
317
  {
318
  "cell_type": "code",
319
+ "execution_count": 8,
320
  "metadata": {},
321
  "outputs": [],
322
  "source": [
323
+ "tokenizer = REMI(params=Path(root_save / tokenizer_name))\n"
324
  ]
325
  },
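+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Quick round-trip sanity check for the reloaded tokenizer, assuming miditok's\n",
+ "# callable tokenize API and v3 decode(); types are printed rather than asserted.\n",
+ "sample_path = midi_paths[0]\n",
+ "tokens = tokenizer(sample_path)           # MIDI file -> token sequence(s)\n",
+ "decoded_score = tokenizer.decode(tokens)  # tokens -> symusic Score\n",
+ "print(type(tokens), type(decoded_score))\n"
+ ]
+ },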
326
  {
 
335
  },
336
  {
337
  "cell_type": "code",
338
+ "execution_count": 9,
339
  "metadata": {},
340
  "outputs": [],
341
  "source": [
342
  "sequence_length = 1024 # The maximum sequence length for data samples.\n",
343
+ "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"pre_tokenize\": True, \"pre_tokenize_thread_count\": 7}"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "markdown",
348
+ "metadata": {},
349
+ "source": [
350
+ "# Test splitting files for training and testing purposes"
351
  ]
352
  },
353
  {
 
356
  "metadata": {},
357
  "outputs": [],
358
  "source": [
359
+ "from pathlib import Path\n",
360
+ "# Split will need to add the BPM to the files its split\n",
361
+ "# \n",
362
+ "file_paths_test = [\n",
363
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Right Here, Right Now.mid'),\n",
364
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Praise You.mid'),\n",
365
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Iris.mid'),\n",
366
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Slide.mid'),\n",
367
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/James Brown/Sex Machine (Get Up I Feel Like Being A).mid'),\n",
368
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.1.mid'),\n",
369
+ " Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.mid')\n",
370
+ "]\n",
 
 
 
 
 
 
 
 
 
371
  "\n",
372
+ "split_files_for_training(\n",
373
+ " files_paths=file_paths_test,\n",
374
+ " tokenizer=tokenizer,\n",
375
+ " save_dir=Path('/home/wombat/Documents/projects/music/midiTok/data/HuggingFace_Mistral_Transformer_Single_Instrument/test'),\n",
376
+ " max_seq_len=sequence_length,\n",
377
+ " num_overlap_bars=2,\n",
378
+ " skip_drums=True\n",
379
+ ")"
380
  ]
381
  },
382
  {
 
387
  "source": [
388
  "# Split MIDI paths in train/valid/test sets\n",
389
  "total_num_files = len(midi_paths)\n",
390
+ "\n",
391
  "num_files_valid = round(total_num_files * 0.15)\n",
392
  "num_files_test = round(total_num_files * 0.15)\n",
393
  "shuffle(midi_paths)\n",
 
411
  " save_dir=subset_chunks_dir,\n",
412
  " max_seq_len=sequence_length,\n",
413
  " num_overlap_bars=2,\n",
414
+ " skip_drums=True\n",
415
  " )\n",
416
  "\n",
417
  " if subset_name == 'train':\n",
 
425
  " )\n"
426
  ]
427
  },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "metadata": {},
432
+ "outputs": [],
433
+ "source": [
434
+ "#Since the datasets are too large after splitting we only want 50% of the split data to train against\n",
435
+ "sample_subset_per = .25"
436
+ ]
437
+ },
438
  {
439
  "cell_type": "code",
440
  "execution_count": null,
 
443
  "source": [
444
  "# Create Dataset and Collator for training\n",
445
  "midi_paths_train = list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.midi\"))\n",
446
+ "sample_count = round(len(midi_paths_train)*sample_subset_per)\n",
447
+ "print(f\"sample count length: {sample_count} total count: {len(midi_paths_train)}\")\n",
448
  "\n",
449
+ "midi_paths_train_sample = midi_paths_train[0:sample_count]\n",
450
+ "print(len(midi_paths_train_sample))\n",
451
+ "dataset_train = DatasetMIDI(midi_paths_train_sample, **kwargs_dataset)\n",
452
+ "torch.save(dataset_train, Path(dataset_dir / \"dataset_train.pt\"))"
 
 
453
  ]
454
  },
455
  {
456
+ "cell_type": "code",
457
+ "execution_count": null,
458
  "metadata": {},
459
+ "outputs": [],
460
  "source": [
461
+ "midi_paths_valid = list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.midi\")) \n",
462
+ "midi_paths_valid = midi_paths_valid[:(len(midi_paths_valid)*sample_subset_per)]\n",
463
+ "print(len(midi_paths_valid))\n",
464
+ "dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)\n",
465
+ "torch.save(dataset_valid, Path(dataset_dir / \"dataset_valid.pt\"))"
466
  ]
467
  },
468
  {
469
  "cell_type": "code",
470
  "execution_count": null,
471
+ "metadata": {},
 
 
 
 
472
  "outputs": [],
473
  "source": [
474
+ "midi_paths_test = list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.midi\"))\n",
475
+ "midi_paths_test = midi_paths_test[:(len(midi_paths_test)*sample_subset_per)]\n",
476
+ "print(len(midi_paths_test))\n",
477
+ "dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)\n",
478
+ "torch.save(dataset_test, Path(dataset_dir / \"dataset_test.pt\"))\n"
 
 
 
 
 
 
 
 
479
  ]
480
  },
481
  {
 
484
  "metadata": {},
485
  "outputs": [],
486
  "source": [
487
+ "print (len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))\n"
 
 
 
 
 
 
 
 
 
 
 
488
  ]
489
  },
490
  {
 
496
  },
497
  {
498
  "cell_type": "code",
499
+ "execution_count": 10,
500
  "metadata": {},
501
  "outputs": [],
502
  "source": [
503
+ "\n",
504
+ "dataset_train = torch.load(Path(dataset_dir / \"dataset_train.pt\"), weights_only=False)\n",
505
+ "dataset_valid = torch.load(Path(dataset_dir / \"dataset_valid.pt\"), weights_only=False)\n",
506
+ "dataset_test = torch.load(Path(dataset_dir / \"dataset_test.pt\"), weights_only=False)\n",
507
+ "\n"
508
  ]
509
  },
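+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal sketch of how these datasets get batched: miditok's DataCollator pads\n",
+ "# each batch and, with copy_inputs_as_labels=True, duplicates inputs as labels for\n",
+ "# causal-LM training.\n",
+ "collator = DataCollator(tokenizer.pad_token_id, copy_inputs_as_labels=True)\n",
+ "loader = DataLoader(dataset_train, batch_size=4, collate_fn=collator)\n",
+ "batch = next(iter(loader))\n",
+ "print(batch[\"input_ids\"].shape)  # (batch_size, seq_len), seq_len <= 1024\n"
+ ]
+ },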
510
  {
 
513
  "metadata": {},
514
  "outputs": [],
515
  "source": [
516
+ "import pickle\n",
517
+ "\n",
518
+ "test_file = open(Path(dataset_dir / \"dataset_test.pickle\"), 'ab')\n",
519
+ "pickle.dump(dataset_test, test_file)\n",
520
+ "test_file.close()\n",
521
+ "\n",
522
+ "print(dataset_test[0])\n",
523
+ "\n",
524
+ "test_file = open(Path(dataset_dir / \"dataset_test.pickle\"), 'rb')\n",
525
+ "test_pickle = pickle.load(test_file)\n",
526
+ "print(test_pickle)\n",
527
+ "print(test_pickle[0])\n",
528
+ "\n",
529
+ "\n"
530
  ]
531
  },
532
  {
533
+ "cell_type": "markdown",
 
534
  "metadata": {},
 
535
  "source": [
536
+ "# Preview files data load and split"
 
 
 
 
537
  ]
538
  },
539
  {
 
542
  "metadata": {},
543
  "outputs": [],
544
  "source": [
545
+ "\n",
546
+ "#testing_files = \n",
547
+ "preview_files_path = []\n",
548
+ "for testing_file in testing_files:\n",
549
+ " preview_files_path.append(Path(testing_file))\n",
550
+ "\n",
551
+ "preview_dir = Path(root_save / \"preview\")\n",
552
+ "split_files_for_training(\n",
553
+ " files_paths=preview_files_path,\n",
554
+ " tokenizer=tokenizer,\n",
555
+ " save_dir=preview_dir,\n",
556
+ " max_seq_len=sequence_length,\n",
557
+ " num_overlap_bars=2,\n",
558
+ " )\n",
559
+ "\n",
560
+ "valid_midi_path = root_save / \"Maestro_valid\"\n",
561
+ "midi_split_preview = list(valid_midi_path.resolve().glob(\"**/*.mid\")) + list(valid_midi_path.resolve().glob(\"**/*.midi\"))\n",
562
+ "\n",
563
+ "print(len(midi_split_preview))\n",
564
+ "file_name_lookup = []\n",
565
+ "def func_to_get_labels(p1, p2, p3):\n",
566
+ " if p3.name not in file_name_lookup:\n",
567
+ " file_name_lookup.append(p3.name)\n",
568
+ " return file_name_lookup.index(p3.name)\n",
569
+ " \n",
570
+ "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"func_to_get_labels\" : func_to_get_labels}\n",
571
+ "dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)"
572
  ]
573
  },
574
  {
 
586
  },
587
  {
588
  "cell_type": "code",
589
+ "execution_count": 19,
590
  "metadata": {},
591
+ "outputs": [
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "Generate config GenerationConfig {\n",
597
+ " \"bos_token_id\": 1,\n",
598
+ " \"eos_token_id\": 2,\n",
599
+ " \"pad_token_id\": 0\n",
600
+ "}\n",
601
+ "\n"
602
+ ]
603
+ }
604
+ ],
605
  "source": [
606
  "# Creates model\n",
607
  "model_config = MistralConfig(\n",
 
612
  " num_attention_heads=8, # default 32\n",
613
  " num_key_value_heads=4, # default 8\n",
614
  " sliding_window=256, # default 4096\n",
615
+ " max_position_embeddings=8192, #has no effect on the parms count or training just limits the input length # default 4096*32\n",
616
  " pad_token_id=tokenizer['PAD_None'],\n",
617
  " bos_token_id=tokenizer['BOS_None'],\n",
618
  " eos_token_id=tokenizer['EOS_None'],\n",
619
  ")\n",
620
+ "\n",
621
+ "model = AutoModelForCausalLM.from_config(model_config)\n"
622
+ ]
623
+ },
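+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sanity-check the model size; with this config the trainer later reports\n",
+ "# 56,041,984 trainable parameters.\n",
+ "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "print(f\"Trainable parameters: {num_params:,}\")\n"
+ ]
+ },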
624
+ {
625
+ "cell_type": "code",
626
+ "execution_count": 22,
627
+ "metadata": {},
628
+ "outputs": [
629
+ {
630
+ "name": "stderr",
631
+ "output_type": "stream",
632
+ "text": [
633
+ "loading configuration file /workspace/traindata/train/checkpoint-22000/config.json\n"
634
+ ]
635
+ },
636
+ {
637
+ "name": "stderr",
638
+ "output_type": "stream",
639
+ "text": [
640
+ "Model config MistralConfig {\n",
641
+ " \"architectures\": [\n",
642
+ " \"MistralForCausalLM\"\n",
643
+ " ],\n",
644
+ " \"attention_dropout\": 0.0,\n",
645
+ " \"bos_token_id\": 1,\n",
646
+ " \"dtype\": \"float32\",\n",
647
+ " \"eos_token_id\": 2,\n",
648
+ " \"head_dim\": null,\n",
649
+ " \"hidden_act\": \"silu\",\n",
650
+ " \"hidden_size\": 512,\n",
651
+ " \"initializer_range\": 0.02,\n",
652
+ " \"intermediate_size\": 2048,\n",
653
+ " \"max_position_embeddings\": 8192,\n",
654
+ " \"model_type\": \"mistral\",\n",
655
+ " \"num_attention_heads\": 8,\n",
656
+ " \"num_hidden_layers\": 8,\n",
657
+ " \"num_key_value_heads\": 4,\n",
658
+ " \"pad_token_id\": 0,\n",
659
+ " \"rms_norm_eps\": 1e-06,\n",
660
+ " \"rope_theta\": 10000.0,\n",
661
+ " \"sliding_window\": 256,\n",
662
+ " \"tie_word_embeddings\": false,\n",
663
+ " \"transformers_version\": \"4.56.2\",\n",
664
+ " \"use_cache\": true,\n",
665
+ " \"vocab_size\": 24000\n",
666
+ "}\n",
667
+ "\n",
668
+ "loading weights file /workspace/traindata/train/checkpoint-22000/model.safetensors\n",
669
+ "Generate config GenerationConfig {\n",
670
+ " \"bos_token_id\": 1,\n",
671
+ " \"eos_token_id\": 2,\n",
672
+ " \"pad_token_id\": 0\n",
673
+ "}\n",
674
+ "\n",
675
+ "All model checkpoint weights were used when initializing MistralForCausalLM.\n",
676
+ "\n",
677
+ "All the weights of MistralForCausalLM were initialized from the model checkpoint at /workspace/traindata/train/checkpoint-22000/model.safetensors.\n",
678
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use MistralForCausalLM for predictions without further training.\n",
679
+ "Generation config file not found, using a generation config created from the model config.\n"
680
+ ]
681
+ }
682
+ ],
683
+ "source": [
684
+ "# This is only for training existing models not new ones\n",
685
+ "model_dir = Path(\"/workspace/traindata/train/checkpoint-22000\")\n",
686
+ "\n",
687
+ "config = AutoConfig.from_pretrained(str(model_dir / \"config.json\"))\n",
688
+ "model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=str(model_dir / \"model.safetensors\"), from_tf=False, config=config)"
689
  ]
690
  },
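+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Equivalent, simpler form: from_pretrained also accepts the checkpoint directory\n",
+ "# itself and picks up config.json and model.safetensors from it.\n",
+ "# model = AutoModelForCausalLM.from_pretrained(str(model_dir))\n"
+ ]
+ },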
691
  {
 
698
  },
699
  {
700
  "cell_type": "code",
701
+ "execution_count": 12,
702
  "metadata": {},
703
+ "outputs": [
704
+ {
705
+ "name": "stdout",
706
+ "output_type": "stream",
707
+ "text": [
708
+ "/workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run\n"
709
+ ]
710
+ }
711
+ ],
712
  "source": [
713
  "model_dir = root_save / 'run'\n",
714
  "model_dir_str = str(model_dir)\n",
 
719
  "cell_type": "code",
720
  "execution_count": null,
721
  "metadata": {},
722
+ "outputs": [
723
+ {
724
+ "name": "stderr",
725
+ "output_type": "stream",
726
+ "text": [
727
+ "PyTorch: setting up devices\n",
728
+ "average_tokens_across_devices is True but world size is 1. Setting it to False automatically.\n",
729
+ "max_steps is given, it will override any value given in num_train_epochs\n",
730
+ "Using auto half precision backend\n"
731
+ ]
732
+ },
733
+ {
734
+ "name": "stdout",
735
+ "output_type": "stream",
736
+ "text": [
737
+ "True\n"
738
+ ]
739
+ }
740
+ ],
741
  "source": [
742
  "metrics = {metric: load_metric(metric) for metric in [\"accuracy\"]}\n",
743
  "\n",
 
778
  "USE_MPS = not USE_CUDA and mps_available()\n",
779
  "training_config = TrainingArguments(\n",
780
  " model_dir_str, False, True, True, False, \"steps\",\n",
781
+ " per_device_train_batch_size=32, #76% @ 24 batch size #76% @ 32 batch size try 64 batch size next time \n",
782
+ " per_device_eval_batch_size=32, #was 24 now 32\n",
783
  " gradient_accumulation_steps=3, #change this to 4\n",
784
  " eval_accumulation_steps=None,\n",
785
+ " eval_steps=3000,\n",
786
+ " eval_delay=6000,\n",
787
  " learning_rate=1e-4,\n",
788
  " weight_decay=0.01,\n",
789
+ " max_grad_norm=1.0,\n",
790
+ " max_steps=30000,\n",
791
+ " lr_scheduler_type=\"cosine\",\n",
792
+ " warmup_ratio=0.08,\n",
793
  " log_level=\"debug\",\n",
794
  " logging_strategy=\"steps\",\n",
795
+ " logging_steps=100,\n",
796
  " save_strategy=\"steps\",\n",
797
+ " save_steps=3000,\n",
798
  " save_total_limit=5,\n",
799
  " no_cuda=not USE_CUDA,\n",
800
  " seed=444,\n",
 
805
  " load_best_model_at_end=True,\n",
806
  " label_smoothing_factor=0.,\n",
807
  " optim=\"adamw_torch\",\n",
808
+ " report_to=[\"tensorboard\", \"wandb\"],\n",
809
+ " gradient_checkpointing=False,\n",
810
  " dataloader_num_workers=8, #added to fix trashing isssue with the gpu not having enough data to process\n",
811
  " dataloader_pin_memory=True, #we want the dataset in memory\n",
812
+ " torch_compile=False #added to speed up \n",
813
  " \n",
814
  ")\n",
815
  "\n",
 
834
  "metadata": {},
835
  "outputs": [],
836
  "source": [
837
+ "torch.cuda.empty_cache()"
838
+ ]
839
+ },
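+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Worked numbers behind the trainer log: 32 sequences per device x 3 accumulation\n",
+ "# steps = 96 sequences per optimizer step, i.e. 96 * 1024 = 98,304 tokens per update.\n",
+ "effective_batch = 32 * 3\n",
+ "print(effective_batch, effective_batch * sequence_length)\n"
+ ]
+ },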
840
+ {
841
+ "cell_type": "code",
842
+ "execution_count": null,
843
+ "metadata": {},
844
+ "outputs": [],
845
+ "source": [
846
+ "print(model)\n",
847
+ "os.environ['CUDA_LAUNCH_BLOCKING']=\"1\"\n",
848
+ "os.environ['TORCH_USE_CUDA_DSA'] = \"1\""
849
  ]
850
  },
851
  {
 
853
  "execution_count": null,
854
  "metadata": {},
855
  "outputs": [],
856
+ "source": [
857
+ "%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
858
+ ]
859
+ },
860
+ {
861
+ "cell_type": "code",
862
+ "execution_count": null,
863
+ "metadata": {},
864
+ "outputs": [],
865
+ "source": [
866
+ "seed=1\n",
867
+ "# # print(max(dataset_train[\"input_ids\"].max().item(), 0))\n",
868
+ "\n",
869
+ "torch.manual_seed(seed)\n",
870
+ "\n"
871
+ ]
872
+ },
873
+ {
874
+ "cell_type": "code",
875
+ "execution_count": 24,
876
+ "metadata": {},
877
+ "outputs": [
878
+ {
879
+ "name": "stderr",
880
+ "output_type": "stream",
881
+ "text": [
882
+ "Currently training with a batch size of: 32\n",
883
+ "***** Running training *****\n",
884
+ " Num examples = 5,570,752\n",
885
+ " Num Epochs = 1\n",
886
+ " Instantaneous batch size per device = 32\n",
887
+ " Total train batch size (w. parallel, distributed & accumulation) = 96\n",
888
+ " Gradient Accumulation steps = 3\n",
889
+ " Total optimization steps = 30,000\n",
890
+ " Number of trainable parameters = 56,041,984\n",
891
+ "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
892
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n"
893
+ ]
894
+ },
895
+ {
896
+ "data": {
897
+ "text/html": [
898
+ "\n",
899
+ " <div>\n",
900
+ " \n",
901
+ " <progress value='15166' max='30000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
902
+ " [15166/30000 3:11:26 < 3:07:16, 1.32 it/s, Epoch 0.26/1]\n",
903
+ " </div>\n",
904
+ " <table border=\"1\" class=\"dataframe\">\n",
905
+ " <thead>\n",
906
+ " <tr style=\"text-align: left;\">\n",
907
+ " <th>Step</th>\n",
908
+ " <th>Training Loss</th>\n",
909
+ " <th>Validation Loss</th>\n",
910
+ " <th>Accuracy</th>\n",
911
+ " </tr>\n",
912
+ " </thead>\n",
913
+ " <tbody>\n",
914
+ " <tr>\n",
915
+ " <td>6000</td>\n",
916
+ " <td>1.602200</td>\n",
917
+ " <td>1.751858</td>\n",
918
+ " <td>0.010508</td>\n",
919
+ " </tr>\n",
920
+ " <tr>\n",
921
+ " <td>9000</td>\n",
922
+ " <td>1.584300</td>\n",
923
+ " <td>1.732275</td>\n",
924
+ " <td>0.010426</td>\n",
925
+ " </tr>\n",
926
+ " <tr>\n",
927
+ " <td>12000</td>\n",
928
+ " <td>1.547300</td>\n",
929
+ " <td>1.712772</td>\n",
930
+ " <td>0.010505</td>\n",
931
+ " </tr>\n",
932
+ " <tr>\n",
933
+ " <td>15000</td>\n",
934
+ " <td>1.540700</td>\n",
935
+ " <td>1.694235</td>\n",
936
+ " <td>0.010407</td>\n",
937
+ " </tr>\n",
938
+ " </tbody>\n",
939
+ "</table><p>"
940
+ ],
941
+ "text/plain": [
942
+ "<IPython.core.display.HTML object>"
943
+ ]
944
+ },
945
+ "metadata": {},
946
+ "output_type": "display_data"
947
+ },
948
+ {
949
+ "name": "stderr",
950
+ "output_type": "stream",
951
+ "text": [
952
+ "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000\n",
953
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000/config.json\n",
954
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000/generation_config.json\n",
955
+ "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000/model.safetensors\n",
956
+ "\n",
957
+ "***** Running Evaluation *****\n",
958
+ " Num examples = 849907\n",
959
+ " Batch size = 32\n",
960
+ "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000\n",
961
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000/config.json\n",
962
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000/generation_config.json\n",
963
+ "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000/model.safetensors\n",
964
+ "\n",
965
+ "***** Running Evaluation *****\n",
966
+ " Num examples = 849907\n",
967
+ " Batch size = 32\n",
968
+ "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000\n",
969
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000/config.json\n",
970
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000/generation_config.json\n",
971
+ "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000/model.safetensors\n",
972
+ "\n",
973
+ "***** Running Evaluation *****\n",
974
+ " Num examples = 849907\n",
975
+ " Batch size = 32\n",
976
+ "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000\n",
977
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000/config.json\n",
978
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000/generation_config.json\n",
979
+ "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000/model.safetensors\n",
980
+ "\n",
981
+ "***** Running Evaluation *****\n",
982
+ " Num examples = 849907\n",
983
+ " Batch size = 32\n",
984
+ "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000\n",
985
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000/config.json\n",
986
+ "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000/generation_config.json\n",
987
+ "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000/model.safetensors\n"
988
+ ]
989
+ },
990
+ {
991
+ "ename": "KeyboardInterrupt",
992
+ "evalue": "",
993
+ "output_type": "error",
994
+ "traceback": [
995
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
996
+ "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
997
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Training\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m train_result = \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m trainer.save_model() \u001b[38;5;66;03m# Saves the tokenizer too\u001b[39;00m\n\u001b[32m 4\u001b[39m trainer.log_metrics(\u001b[33m\"\u001b[39m\u001b[33mtrain\u001b[39m\u001b[33m\"\u001b[39m, train_result.metrics)\n",
998
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/trainer.py:2328\u001b[39m, in \u001b[36mTrainer.train\u001b[39m\u001b[34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[39m\n\u001b[32m 2326\u001b[39m hf_hub_utils.enable_progress_bars()\n\u001b[32m 2327\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2328\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2329\u001b[39m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m=\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2330\u001b[39m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m=\u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2331\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2332\u001b[39m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m=\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2333\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
999
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/trainer.py:2672\u001b[39m, in \u001b[36mTrainer._inner_training_loop\u001b[39m\u001b[34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[39m\n\u001b[32m 2665\u001b[39m context = (\n\u001b[32m 2666\u001b[39m functools.partial(\u001b[38;5;28mself\u001b[39m.accelerator.no_sync, model=model)\n\u001b[32m 2667\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i != \u001b[38;5;28mlen\u001b[39m(batch_samples) - \u001b[32m1\u001b[39m\n\u001b[32m 2668\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.accelerator.distributed_type != DistributedType.DEEPSPEED\n\u001b[32m 2669\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m contextlib.nullcontext\n\u001b[32m 2670\u001b[39m )\n\u001b[32m 2671\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[32m-> \u001b[39m\u001b[32m2672\u001b[39m tr_loss_step = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2674\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 2675\u001b[39m args.logging_nan_inf_filter\n\u001b[32m 2676\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[32m 2677\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m (torch.isnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch.isinf(tr_loss_step))\n\u001b[32m 2678\u001b[39m ):\n\u001b[32m 2679\u001b[39m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[32m 2680\u001b[39m tr_loss = tr_loss + tr_loss / (\u001b[32m1\u001b[39m + \u001b[38;5;28mself\u001b[39m.state.global_step - \u001b[38;5;28mself\u001b[39m._globalstep_last_logged)\n",
1000
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/trainer.py:4060\u001b[39m, in \u001b[36mTrainer.training_step\u001b[39m\u001b[34m(***failed resolving arguments***)\u001b[39m\n\u001b[32m 4057\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.accelerator.distributed_type == DistributedType.DEEPSPEED:\n\u001b[32m 4058\u001b[39m kwargs[\u001b[33m\"\u001b[39m\u001b[33mscale_wrt_gas\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m4060\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43maccelerator\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 4062\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m loss.detach()\n",
1001
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/accelerate/accelerator.py:2734\u001b[39m, in \u001b[36mAccelerator.backward\u001b[39m\u001b[34m(self, loss, **kwargs)\u001b[39m\n\u001b[32m 2732\u001b[39m \u001b[38;5;28mself\u001b[39m.lomo_backward(loss, learning_rate)\n\u001b[32m 2733\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2734\u001b[39m \u001b[43mloss\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
1002
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/torch/_tensor.py:648\u001b[39m, in \u001b[36mTensor.backward\u001b[39m\u001b[34m(self, gradient, retain_graph, create_graph, inputs)\u001b[39m\n\u001b[32m 638\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 639\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[32m 640\u001b[39m Tensor.backward,\n\u001b[32m 641\u001b[39m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[32m (...)\u001b[39m\u001b[32m 646\u001b[39m inputs=inputs,\n\u001b[32m 647\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m648\u001b[39m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mautograd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 649\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs\u001b[49m\n\u001b[32m 650\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1003
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py:353\u001b[39m, in \u001b[36mbackward\u001b[39m\u001b[34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[39m\n\u001b[32m 348\u001b[39m retain_graph = create_graph\n\u001b[32m 350\u001b[39m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[32m 351\u001b[39m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[32m 352\u001b[39m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m353\u001b[39m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 354\u001b[39m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 355\u001b[39m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 356\u001b[39m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 357\u001b[39m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 358\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 359\u001b[39m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 360\u001b[39m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 361\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1004
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py:824\u001b[39m, in \u001b[36m_engine_run_backward\u001b[39m\u001b[34m(t_outputs, *args, **kwargs)\u001b[39m\n\u001b[32m 822\u001b[39m unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[32m 823\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m824\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_execution_engine\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[32m 825\u001b[39m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m 826\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[32m 827\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 828\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
1005
+ "\u001b[31mKeyboardInterrupt\u001b[39m: "
1006
+ ]
1007
+ }
1008
+ ],
1009
  "source": [
1010
  "# Training\n",
1011
  "train_result = trainer.train()\n",
 
1015
  "trainer.save_state()"
1016
  ]
1017
  },
1018
+ {
1019
+ "cell_type": "code",
1020
+ "execution_count": 25,
1021
+ "metadata": {},
1022
+ "outputs": [
1023
+ {
1024
+ "data": {
1025
+ "text/html": [],
1026
+ "text/plain": [
1027
+ "<IPython.core.display.HTML object>"
1028
+ ]
1029
+ },
1030
+ "metadata": {},
1031
+ "output_type": "display_data"
1032
+ },
1033
+ {
1034
+ "data": {
1035
+ "text/html": [
1036
+ "<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>█▂█▁</td></tr><tr><td>eval/loss</td><td>█▆▃▁</td></tr><tr><td>eval/runtime</td><td>█▁▃▆</td></tr><tr><td>eval/samples_per_second</td><td>▁█▆▃</td></tr><tr><td>eval/steps_per_second</td><td>▁█▆▃</td></tr><tr><td>train/epoch</td><td>▁▁▁▁▁▂▂▂▂▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇███</td></tr><tr><td>train/global_step</td><td>▁▁▁▁▂▂▂▂▂▂▁▁▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇████</td></tr><tr><td>train/grad_norm</td><td>▅▃▄▅▆█▇▆▁▁▂▂▂▂▁▂▁▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train/learning_rate</td><td>▃▅▅▆▆█▁▂▄▄▆▇█████████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅</td></tr><tr><td>train/loss</td><td>█▇▆▆▆▅▄▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>0.01041</td></tr><tr><td>eval/loss</td><td>1.69424</td></tr><tr><td>eval/runtime</td><td>1748.5708</td></tr><tr><td>eval/samples_per_second</td><td>486.058</td></tr><tr><td>eval/steps_per_second</td><td>15.19</td></tr><tr><td>train/epoch</td><td>0.26022</td></tr><tr><td>train/global_step</td><td>15100</td></tr><tr><td>train/grad_norm</td><td>0.56171</td></tr><tr><td>train/learning_rate</td><td>6e-05</td></tr><tr><td>train/loss</td><td>1.5401</td></tr></table><br/></div></div>"
1037
+ ],
1038
+ "text/plain": [
1039
+ "<IPython.core.display.HTML object>"
1040
+ ]
1041
+ },
1042
+ "metadata": {},
1043
+ "output_type": "display_data"
1044
+ },
1045
+ {
1046
+ "data": {
1047
+ "text/html": [
1048
+ " View run <strong style=\"color:#cdcd00\">fast-yogurt-6</strong> at: <a href='https://wandb.ai/adric-landman-hobby/midi_music_maker/runs/g1fn393k' target=\"_blank\">https://wandb.ai/adric-landman-hobby/midi_music_maker/runs/g1fn393k</a><br> View project at: <a href='https://wandb.ai/adric-landman-hobby/midi_music_maker' target=\"_blank\">https://wandb.ai/adric-landman-hobby/midi_music_maker</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
1049
+ ],
1050
+ "text/plain": [
1051
+ "<IPython.core.display.HTML object>"
1052
+ ]
1053
+ },
1054
+ "metadata": {},
1055
+ "output_type": "display_data"
1056
+ },
1057
+ {
1058
+ "data": {
1059
+ "text/html": [
1060
+ "Find logs at: <code></code>"
1061
+ ],
1062
+ "text/plain": [
1063
+ "<IPython.core.display.HTML object>"
1064
+ ]
1065
+ },
1066
+ "metadata": {},
1067
+ "output_type": "display_data"
1068
+ }
1069
+ ],
1070
+ "source": [
1071
+ "wandb.finish()"
1072
+ ]
1073
+ },
1074
  {
1075
  "cell_type": "code",
1076
  "execution_count": null,
 
1083
  },
1084
  {
1085
  "cell_type": "code",
1086
+ "execution_count": 26,
1087
  "metadata": {},
1088
+ "outputs": [
1089
+ {
1090
+ "ename": "HfHubHTTPError",
1091
+ "evalue": "401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create (Request ID: Root=1-68d628b1-575691d056937c56340182c7;4f0ea033-8e43-4260-938f-d74b744942be)\n\nInvalid credentials in Authorization header",
1092
+ "output_type": "error",
1093
+ "traceback": [
1094
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
1095
+ "\u001b[31mHTTPError\u001b[39m Traceback (most recent call last)",
1096
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_http.py:407\u001b[39m, in \u001b[36mhf_raise_for_status\u001b[39m\u001b[34m(response, endpoint_name)\u001b[39m\n\u001b[32m 406\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m407\u001b[39m \u001b[43mresponse\u001b[49m\u001b[43m.\u001b[49m\u001b[43mraise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 408\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
1097
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/requests/models.py:1024\u001b[39m, in \u001b[36mResponse.raise_for_status\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1023\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m http_error_msg:\n\u001b[32m-> \u001b[39m\u001b[32m1024\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m HTTPError(http_error_msg, response=\u001b[38;5;28mself\u001b[39m)\n",
1098
+ "\u001b[31mHTTPError\u001b[39m: 401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create",
1099
+ "\nThe above exception was the direct cause of the following exception:\n",
1100
+ "\u001b[31mHfHubHTTPError\u001b[39m Traceback (most recent call last)",
1101
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[26]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m model.hub_model_id = \u001b[33m\"\u001b[39m\u001b[33madricl/midi_single_instrument_mistral_transformer\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommit_message\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mTraining Basic Model for Mistral MidiTok Transformer Single Instrument Small\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43madricl/midi_single_instrument_mistral_transformer\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 4\u001b[39m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
1102
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py:4346\u001b[39m, in \u001b[36mPreTrainedModel.push_to_hub\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 4344\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m tags:\n\u001b[32m 4345\u001b[39m kwargs[\u001b[33m\"\u001b[39m\u001b[33mtags\u001b[39m\u001b[33m\"\u001b[39m] = tags\n\u001b[32m-> \u001b[39m\u001b[32m4346\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
1103
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/utils/hub.py:955\u001b[39m, in \u001b[36mPushToHubMixin.push_to_hub\u001b[39m\u001b[34m(self, repo_id, use_temp_dir, commit_message, private, token, max_shard_size, create_pr, safe_serialization, revision, commit_description, tags, **deprecated_kwargs)\u001b[39m\n\u001b[32m 952\u001b[39m repo_url = deprecated_kwargs.pop(\u001b[33m\"\u001b[39m\u001b[33mrepo_url\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m 953\u001b[39m organization = deprecated_kwargs.pop(\u001b[33m\"\u001b[39m\u001b[33morganization\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m--> \u001b[39m\u001b[32m955\u001b[39m repo_id = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_create_repo\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 956\u001b[39m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrepo_url\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrepo_url\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morganization\u001b[49m\u001b[43m=\u001b[49m\u001b[43morganization\u001b[49m\n\u001b[32m 957\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 959\u001b[39m \u001b[38;5;66;03m# Create a new empty model card and eventually tag it\u001b[39;00m\n\u001b[32m 960\u001b[39m model_card = create_and_tag_model_card(\n\u001b[32m 961\u001b[39m repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors\n\u001b[32m 962\u001b[39m )\n",
1104
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/utils/hub.py:759\u001b[39m, in \u001b[36mPushToHubMixin._create_repo\u001b[39m\u001b[34m(self, repo_id, private, token, repo_url, organization)\u001b[39m\n\u001b[32m 756\u001b[39m repo_id = repo_id.split(\u001b[33m\"\u001b[39m\u001b[33m/\u001b[39m\u001b[33m\"\u001b[39m)[-\u001b[32m1\u001b[39m]\n\u001b[32m 757\u001b[39m repo_id = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00morganization\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m759\u001b[39m url = \u001b[43mcreate_repo\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexist_ok\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 760\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m url.repo_id\n",
1105
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py:114\u001b[39m, in \u001b[36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 111\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[32m 112\u001b[39m kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.\u001b[34m__name__\u001b[39m, has_token=has_token, kwargs=kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
1106
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py:3766\u001b[39m, in \u001b[36mHfApi.create_repo\u001b[39m\u001b[34m(self, repo_id, token, private, repo_type, exist_ok, resource_group_id, space_sdk, space_hardware, space_storage, space_sleep_time, space_secrets, space_variables)\u001b[39m\n\u001b[32m 3763\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 3765\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m3766\u001b[39m \u001b[43mhf_raise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43mr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3767\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[32m 3768\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m exist_ok \u001b[38;5;129;01mand\u001b[39;00m err.response.status_code == \u001b[32m409\u001b[39m:\n\u001b[32m 3769\u001b[39m \u001b[38;5;66;03m# Repo already exists and `exist_ok=True`\u001b[39;00m\n",
1107
+ "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_http.py:480\u001b[39m, in \u001b[36mhf_raise_for_status\u001b[39m\u001b[34m(response, endpoint_name)\u001b[39m\n\u001b[32m 476\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m _format(HfHubHTTPError, message, response) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m 478\u001b[39m \u001b[38;5;66;03m# Convert `HTTPError` into a `HfHubHTTPError` to display request information\u001b[39;00m\n\u001b[32m 479\u001b[39m \u001b[38;5;66;03m# as well (request id and/or server error message)\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m480\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m _format(HfHubHTTPError, \u001b[38;5;28mstr\u001b[39m(e), response) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n",
1108
+ "\u001b[31mHfHubHTTPError\u001b[39m: 401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create (Request ID: Root=1-68d628b1-575691d056937c56340182c7;4f0ea033-8e43-4260-938f-d74b744942be)\n\nInvalid credentials in Authorization header"
1109
+ ]
1110
+ }
1111
+ ],
1112
  "source": [
1113
  "\n",
1114
  "model.hub_model_id = \"adricl/midi_single_instrument_mistral_transformer\"\n",
 
1238
  "name": "python",
1239
  "nbconvert_exporter": "python",
1240
  "pygments_lexer": "ipython3",
1241
+ "version": "3.11.11"
1242
  },
1243
  "vscode": {
1244
  "interpreter": {
train_tokenizer.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from miditok import REMI, TokenizerConfig
2
+ from random import shuffle, sample
3
+ from pathlib import Path
4
+
5
+ # Our tokenizer's configuration
6
+ BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}
7
+ TOKENIZER_PARAMS = {
8
+ "pitch_range": (21, 108),
9
+ "beat_res": BEAT_RES,
10
+ "num_velocities": 32,
11
+ "special_tokens": ["PAD", "BOS", "EOS"],
12
+ "use_chords": True,
13
+ "use_rests": True,
14
+ "use_tempos": True,
15
+ "use_time_signatures": True,
16
+ "use_programs": False, # We want single track
17
+ "one_token_stream_for_programs": False, # We want single track
18
+ "programs": list(range(0, 128)), #-1 drums, skip drums
19
+ "num_tempos": 32,
20
+ "tempo_range": (40, 250), # (min_tempo, max_tempo)
21
+ }
22
+ config = TokenizerConfig(**TOKENIZER_PARAMS)
23
+
24
+ # Creates the tokenizer REMI PLUS
25
+ tokenizer = REMI(config)
26
+
27
+ root_data_dir = Path('/root')
28
+ root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')
29
+
30
+ tokenizer_name = "HuggingFace_Mistral_Transformer_Single_Instrument_v4_single_track.json"
31
+
32
+
33
+
34
+ data_dirs = ["MIDIs"]
35
+ midi_paths = []
36
+ for data_dir in data_dirs:
37
+ path = Path(root_data_dir / data_dir)
38
+ midi_paths.extend(list(path.resolve().glob("**/*.mid")) + list(path.resolve().glob("**/*.midi")))
39
+
40
+ print(f"Found {len(midi_paths)} MIDI files")
41
+
42
+ shuffle(midi_paths)  # shuffle in place; list has no .shuffle() method
43
+
44
+ # We only need a subset of the files, otherwise training the tokenizer takes too long
45
+ percentage_to_select = 0.20
46
+ num_files_to_select = int(len(midi_paths) * percentage_to_select)
47
+
48
+ subset_midi_paths = sample(midi_paths, num_files_to_select)
49
+ print(f"Found {len(subset_midi_paths)} MIDI files")
50
+
51
+
52
+ tokenizer.train(
53
+ vocab_size=24000,
54
+ files_paths=subset_midi_paths,
55
+ )
56
+ tokenizer.save(root_save / tokenizer_name)
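+
+ # The saved tokenizer can then be reloaded elsewhere (e.g. in the training notebook) with:
+ # tokenizer = REMI(params=root_save / tokenizer_name)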