Added new training data
.ipynb_checkpoints/HuggingFace_Mistral_Transformer_Single_Instrument-checkpoint.ipynb
ADDED
@@ -0,0 +1,803 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "SiTIpPjArIyr"
+   },
+   "source": [
+    "# Using MIDI training data and MidiTok REMI to generate music with a Mistral model\n",
+    "# Split music into single-instrument tracks and chunk it into 1024-token sequences\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "gOd93yV0sGd2"
+   },
+   "source": [
+    "## Setup Environment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# To compile Symusic, get g++ 11 or higher, then:\n",
+    "!git clone --recursive https://github.com/Yikai-Liao/symusic\n",
+    "!CXX=/usr/bin/g++-11 pip install ./symusic\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install torch==2.6.0\n",
+    "%pip install evaluate transformers[torch]==4.55.4 tqdm miditok accelerate tensorboardX scikit-learn\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "cellView": "form",
+    "id": "fX12Yquyuihc"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2025-09-12 09:17:37.410013: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2025-09-12 09:17:38.509451: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+     ]
+    }
+   ],
+   "source": [
+    "from copy import deepcopy\n",
+    "from pathlib import Path\n",
+    "from random import shuffle, sample\n",
+    "\n",
+    "from evaluate import load as load_metric\n",
+    "from miditok import REMI, TokenizerConfig, TokTrainingIterator\n",
+    "from miditok.pytorch_data import DatasetMIDI, DataCollator\n",
+    "from miditok.utils import split_files_for_training\n",
+    "from miditok.data_augmentation import augment_dataset\n",
+    "\n",
+    "import torch\n",
+    "from torch import Tensor, argmax\n",
+    "from torch.utils.data import DataLoader\n",
+    "from torch.cuda import is_available as cuda_available, is_bf16_supported\n",
+    "from torch.backends.mps import is_available as mps_available\n",
+    "from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig\n",
+    "from transformers.trainer_utils import set_seed\n",
+    "from tqdm import tqdm"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup Tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Seed\n",
+    "set_seed(777)\n",
+    "\n",
+    "# Our tokenizer's configuration\n",
+    "BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}\n",
+    "TOKENIZER_PARAMS = {\n",
+    "    \"pitch_range\": (21, 108),\n",
+    "    \"beat_res\": BEAT_RES,\n",
+    "    \"num_velocities\": 32,\n",
+    "    \"special_tokens\": [\"PAD\", \"BOS\", \"EOS\"],\n",
+    "    \"use_chords\": True,\n",
+    "    \"use_rests\": True,\n",
+    "    \"use_tempos\": True,\n",
+    "    \"use_time_signatures\": True,\n",
+    "    \"use_programs\": False,  # we want a single track\n",
+    "    \"one_token_stream_for_programs\": False,  # we want a single track\n",
+    "    \"programs\": list(range(0, 128)),  # -1 is drums; we skip drums\n",
+    "    \"num_tempos\": 32,\n",
+    "    \"tempo_range\": (40, 250),  # (min_tempo, max_tempo)\n",
+    "}\n",
+    "config = TokenizerConfig(**TOKENIZER_PARAMS)\n",
+    "\n",
+    "# Creates the REMI tokenizer\n",
+    "tokenizer = REMI(config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load MIDI files and train the tokenizer on them"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "root_data_dir = Path('/home/wombat/Documents/projects/music/midiTok/data/')\n",
+    "root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')\n",
+    "\n",
+    "tokenizer_name = \"HuggingFace_Mistral_Transformer_Single_Instrument_v4_single_track.json\"\n",
+    "dataset_dir = root_save / \"data\"\n",
+    "dataset_dir.mkdir(parents=True, exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 24k tokens\n",
+    "#data_dirs = [\"adl-piano-midi\", \"maestro-v3.0.0\", \"musicnet_midis\"]  # for single\n",
+    "data_dirs = [\"MIDIs\"]\n",
+    "midi_paths = []\n",
+    "for data_dir in data_dirs:\n",
+    "    path = Path(root_data_dir / 'Traning Data' / data_dir)\n",
+    "    midi_paths.extend(list(path.resolve().glob(\"**/*.mid\")) + list(path.resolve().glob(\"**/*.midi\")))\n",
+    "\n",
+    "print(f\"Found {len(midi_paths)} MIDI files\")\n",
+    "\n",
+    "shuffle(midi_paths)\n",
+    "\n",
+    "# We need a subset of the files, otherwise training the tokenizer takes too long\n",
+    "percentage_to_select = 0.15\n",
+    "num_files_to_select = int(len(midi_paths) * percentage_to_select)\n",
+    "\n",
+    "subset_midi_paths = sample(midi_paths, num_files_to_select)\n",
+    "print(f\"Selected {len(subset_midi_paths)} MIDI files\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Note: the dataset is quite large, so training the tokenizer needs a lot of memory; for 61,749 files it took 64 GB\n",
+    "tokenizer.train(\n",
+    "    vocab_size=24000,\n",
+    "    files_paths=subset_midi_paths,\n",
+    ")\n",
+    "tokenizer.save(root_save / tokenizer_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = REMI(params=Path(root_save / tokenizer_name))"
+   ]
+  },
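A quick way to sanity-check the trained tokenizer is an encode/decode round trip on one file. This sketch is not part of the committed notebook, and `some_file.mid` is a placeholder path; it assumes MidiTok 3.x, where calling the tokenizer encodes a file and `decode` returns a symusic `Score`:

```python
# Round-trip sanity check for the trained tokenizer (illustrative only).
from pathlib import Path

tokens = tokenizer(Path("some_file.mid"))  # encode: MIDI -> token sequence(s)
score = tokenizer.decode(tokens)           # decode: tokens -> symusic Score
score.dump_midi("roundtrip.mid")           # should sound close to the input
```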
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prepare MIDIs for training\n",
+    "\n",
+    "Here we split the files into three subsets: train, validation and test.\n",
+    "Data augmentation is then performed on each subset independently, and the MIDIs are split into smaller chunks of approximately the desired token sequence length for training."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sequence_length = 1024  # the maximum sequence length for data samples\n",
+    "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"pre_tokenize\": True, \"pre_tokenize_thread_count\": 7}"
+   ]
+  },
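A note on the two extra kwargs here: with `pre_tokenize=True`, `DatasetMIDI` tokenizes every chunk once at construction time instead of lazily per item, trading a long build step and more RAM for faster training epochs. `pre_tokenize_thread_count` appears to control the worker count for that build step; it is not a standard documented MidiTok argument, so check that your installed version supports it.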
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Test splitting files for training and testing purposes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pathlib import Path\n",
+    "# Split will need to add the BPM to the files it splits\n",
+    "file_paths_test = [\n",
+    "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Right Here, Right Now.mid'),\n",
+    "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Praise You.mid'),\n",
+    "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Iris.mid'),\n",
+    "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Slide.mid'),\n",
+    "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/James Brown/Sex Machine (Get Up I Feel Like Being A).mid'),\n",
+    "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.1.mid'),\n",
+    "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.mid')\n",
+    "]\n",
+    "\n",
+    "split_files_for_training(\n",
+    "    files_paths=file_paths_test,\n",
+    "    tokenizer=tokenizer,\n",
+    "    save_dir=Path('/home/wombat/Documents/projects/music/midiTok/data/HuggingFace_Mistral_Transformer_Single_Instrument/test'),\n",
+    "    max_seq_len=sequence_length,\n",
+    "    num_overlap_bars=2,\n",
+    "    skip_drums=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Split MIDI paths into train/valid/test sets\n",
+    "total_num_files = len(midi_paths)\n",
+    "num_files_valid = round(total_num_files * 0.15)\n",
+    "num_files_test = round(total_num_files * 0.15)\n",
+    "shuffle(midi_paths)\n",
+    "midi_paths_valid = midi_paths[:num_files_valid]\n",
+    "midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]\n",
+    "midi_paths_train = midi_paths[num_files_valid + num_files_test:]\n",
+    "\n",
+    "# Chunk MIDIs and perform data augmentation on each subset independently\n",
+    "for files_paths, subset_name in (\n",
+    "    (midi_paths_train, \"train\"), (midi_paths_valid, \"valid\"), (midi_paths_test, \"test\")\n",
+    "):\n",
+    "    # Split the MIDIs into chunks of approximately 1024 tokens\n",
+    "    subset_chunks_dir = root_save / f\"Maestro_{subset_name}\"\n",
+    "    print(subset_chunks_dir)\n",
+    "    split_files_for_training(\n",
+    "        files_paths=files_paths,\n",
+    "        tokenizer=tokenizer,\n",
+    "        save_dir=subset_chunks_dir,\n",
+    "        max_seq_len=sequence_length,\n",
+    "        num_overlap_bars=2,\n",
+    "        skip_drums=True\n",
+    "    )\n",
+    "\n",
+    "    if subset_name == 'train':\n",
+    "        print(\"Augmentation\")\n",
+    "        # Perform data augmentation\n",
+    "        augment_dataset(\n",
+    "            subset_chunks_dir,\n",
+    "            pitch_offsets=[-12, 12],\n",
+    "            velocity_offsets=[-4, 4],\n",
+    "            duration_offsets=[-0.5, 0.5],\n",
+    "        )\n"
+   ]
+  },
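A quick worked example of this 70/15/15 split: with 20,000 source files, `round(20000 * 0.15)` gives 3,000 validation and 3,000 test files, leaving 14,000 for training. The chunking and augmentation steps then multiply each subset's file count considerably, which is why the next cells subsample the resulting chunks.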
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Since the datasets are too large after splitting, we only train against 50% of the split data\n",
+    "sample_subset_per = 0.5"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create Dataset and Collator for training\n",
+    "midi_paths_train = list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.midi\"))\n",
+    "# The slice bound must be an int; a float here raises \"TypeError: slice indices must be integers\"\n",
+    "midi_paths_train = midi_paths_train[:int(len(midi_paths_train) * sample_subset_per)]\n",
+    "print(len(midi_paths_train))\n",
+    "dataset_train = DatasetMIDI(midi_paths_train, **kwargs_dataset)\n",
+    "torch.save(dataset_train, Path(dataset_dir / \"dataset_train.pt\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "midi_paths_valid = list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.midi\"))\n",
+    "midi_paths_valid = midi_paths_valid[:int(len(midi_paths_valid) * sample_subset_per)]\n",
+    "print(len(midi_paths_valid))\n",
+    "dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)\n",
+    "torch.save(dataset_valid, Path(dataset_dir / \"dataset_valid.pt\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "midi_paths_test = list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.midi\"))\n",
+    "midi_paths_test = midi_paths_test[:int(len(midi_paths_test) * sample_subset_per)]\n",
+    "print(len(midi_paths_test))\n",
+    "dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)\n",
+    "torch.save(dataset_test, Path(dataset_dir / \"dataset_test.pt\"))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Save and Load datasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset_train = torch.load(Path(dataset_dir / \"dataset_train.pt\"), weights_only=False)\n",
+    "dataset_valid = torch.load(Path(dataset_dir / \"dataset_valid.pt\"), weights_only=False)\n",
+    "dataset_test = torch.load(Path(dataset_dir / \"dataset_test.pt\"), weights_only=False)\n"
+   ]
+  },
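One caution on this reload step: `torch.load` with `weights_only=False` unpickles arbitrary Python objects, so it should only be pointed at files this notebook wrote itself. A minimal guard, using a hypothetical helper that is not in the committed notebook:

```python
# Hypothetical wrapper around the reloads above: fail fast on missing files
# and make the trust assumption explicit.
def load_dataset_file(path):
    if not path.is_file():
        raise FileNotFoundError(f"Run the dataset-building cells first: {path}")
    # weights_only=False is needed because DatasetMIDI is a full Python object,
    # not a tensor state_dict; only load files you created yourself this way.
    return torch.load(path, weights_only=False)

dataset_train = load_dataset_file(dataset_dir / "dataset_train.pt")
```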
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pickle\n",
+    "\n",
+    "# 'wb' rather than 'ab': appending would pile up stale pickles across reruns\n",
+    "with open(Path(dataset_dir / \"dataset_test.pickle\"), 'wb') as test_file:\n",
+    "    pickle.dump(dataset_test, test_file)\n",
+    "\n",
+    "print(dataset_test[0])\n",
+    "\n",
+    "with open(Path(dataset_dir / \"dataset_test.pickle\"), 'rb') as test_file:\n",
+    "    test_pickle = pickle.load(test_file)\n",
+    "print(test_pickle)\n",
+    "print(test_pickle[0])\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Preview files: data load and split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#testing_files = \n",
+    "# testing_files must be defined above as a list of MIDI file paths to preview\n",
+    "preview_files_path = []\n",
+    "for testing_file in testing_files:\n",
+    "    preview_files_path.append(Path(testing_file))\n",
+    "\n",
+    "preview_dir = Path(root_save / \"preview\")\n",
+    "split_files_for_training(\n",
+    "    files_paths=preview_files_path,\n",
+    "    tokenizer=tokenizer,\n",
+    "    save_dir=preview_dir,\n",
+    "    max_seq_len=sequence_length,\n",
+    "    num_overlap_bars=2,\n",
+    ")\n",
+    "\n",
+    "valid_midi_path = root_save / \"Maestro_valid\"\n",
+    "midi_split_preview = list(valid_midi_path.resolve().glob(\"**/*.mid\")) + list(valid_midi_path.resolve().glob(\"**/*.midi\"))\n",
+    "\n",
+    "print(len(midi_split_preview))\n",
+    "file_name_lookup = []\n",
+    "# DatasetMIDI calls this once per file; the third argument is the file's Path.\n",
+    "# We map each file name to a stable integer label so generations can be traced\n",
+    "# back to their source file later.\n",
+    "def func_to_get_labels(p1, p2, p3):\n",
+    "    if p3.name not in file_name_lookup:\n",
+    "        file_name_lookup.append(p3.name)\n",
+    "    return file_name_lookup.index(p3.name)\n",
+    "\n",
+    "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"func_to_get_labels\": func_to_get_labels}\n",
+    "dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)"
+   ]
+  },
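Because `func_to_get_labels` is passed to `DatasetMIDI`, each preview item carries an integer label that indexes into `file_name_lookup`; the generation loop at the end of the notebook uses that lookup to name its output MIDI files after the source files the prompts came from.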
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model initialization\n",
+    "\n",
+    "We will use the [Mistral implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/mistral).\n",
+    "Feel free to explore the documentation and source code to dig deeper.\n",
+    "\n",
+    "**You may need to adjust the model's configuration, the training configuration and the maximum input sequence length (cell above) depending on your hardware.**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Creates the model\n",
+    "model_config = MistralConfig(\n",
+    "    vocab_size=len(tokenizer),  # from the MidiTok vocabulary; default 32k\n",
+    "    hidden_size=512,  # default 4096\n",
+    "    intermediate_size=2048,  # default 14336\n",
+    "    num_hidden_layers=8,  # default 32\n",
+    "    num_attention_heads=8,  # default 32\n",
+    "    num_key_value_heads=4,  # default 8\n",
+    "    sliding_window=256,  # default 4096\n",
+    "    max_position_embeddings=8192,  # no effect on the param count or training; just limits the input length (default 4096*32)\n",
+    "    pad_token_id=tokenizer['PAD_None'],\n",
+    "    bos_token_id=tokenizer['BOS_None'],\n",
+    "    eos_token_id=tokenizer['EOS_None'],\n",
+    ")\n",
+    "model = AutoModelForCausalLM.from_config(model_config)"
+   ]
+  },
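To see how far this scaled-down config shrinks the network from the roughly 7B-parameter stock Mistral, a parameter count right after construction is cheap (a quick check, not in the committed notebook):

```python
# Quick size check for the scaled-down MistralConfig above (illustrative only).
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")
```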
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_dir = root_save / 'run'\n",
+    "model_dir_str = str(model_dir)\n",
+    "print(model_dir)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metrics = {metric: load_metric(metric) for metric in [\"accuracy\"]}\n",
+    "\n",
+    "def compute_metrics(eval_pred):\n",
+    "    \"\"\"\n",
+    "    Compute metrics for pretraining.\n",
+    "\n",
+    "    Must be used with a preprocess_logits function that converts logits to predictions (argmax or sampling).\n",
+    "\n",
+    "    :param eval_pred: EvalPrediction containing predictions and labels\n",
+    "    :return: metrics\n",
+    "    \"\"\"\n",
+    "    predictions, labels = eval_pred\n",
+    "    not_pad_mask = labels != -100\n",
+    "    labels, predictions = labels[not_pad_mask], predictions[not_pad_mask]\n",
+    "    return metrics[\"accuracy\"].compute(predictions=predictions.flatten(), references=labels.flatten())\n",
+    "\n",
+    "def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:\n",
+    "    \"\"\"\n",
+    "    Preprocess the logits before accumulating them during evaluation.\n",
+    "\n",
+    "    This significantly reduces memory usage and makes the evaluation tractable.\n",
+    "    \"\"\"\n",
+    "    pred_ids = argmax(logits, dim=-1)  # long dtype\n",
+    "    return pred_ids\n",
+    "\n",
+    "# Create config for the Trainer\n",
+    "USE_CUDA = cuda_available()\n",
+    "print(USE_CUDA)\n",
+    "if not cuda_available():\n",
+    "    FP16 = FP16_EVAL = BF16 = BF16_EVAL = False\n",
+    "elif is_bf16_supported():\n",
+    "    BF16 = BF16_EVAL = True\n",
+    "    FP16 = FP16_EVAL = False\n",
+    "else:\n",
+    "    BF16 = BF16_EVAL = False\n",
+    "    FP16 = FP16_EVAL = True\n",
+    "USE_MPS = not USE_CUDA and mps_available()\n",
+    "training_config = TrainingArguments(\n",
+    "    # positional args: output_dir, overwrite_output_dir, do_train, do_eval, do_predict, eval_strategy\n",
+    "    model_dir_str, False, True, True, False, \"steps\",\n",
+    "    per_device_train_batch_size=24,  # 76% GPU memory @ batch size 24 and @ 32; try 64 next time\n",
+    "    per_device_eval_batch_size=24,\n",
+    "    gradient_accumulation_steps=3,  # consider raising this to 4\n",
+    "    eval_accumulation_steps=None,\n",
+    "    eval_steps=1000,\n",
+    "    learning_rate=1e-4,\n",
+    "    weight_decay=0.01,\n",
+    "    max_grad_norm=3.0,\n",
+    "    max_steps=40000,\n",
+    "    lr_scheduler_type=\"cosine_with_restarts\",\n",
+    "    warmup_ratio=0.3,\n",
+    "    log_level=\"debug\",\n",
+    "    logging_strategy=\"steps\",\n",
+    "    logging_steps=20,\n",
+    "    save_strategy=\"steps\",\n",
+    "    save_steps=1000,\n",
+    "    save_total_limit=5,\n",
+    "    no_cuda=not USE_CUDA,\n",
+    "    seed=444,\n",
+    "    fp16=FP16,\n",
+    "    fp16_full_eval=FP16_EVAL,\n",
+    "    bf16=BF16,\n",
+    "    bf16_full_eval=BF16_EVAL,\n",
+    "    load_best_model_at_end=True,\n",
+    "    label_smoothing_factor=0.,\n",
+    "    optim=\"adamw_torch\",\n",
+    "    report_to=[\"tensorboard\"],\n",
+    "    gradient_checkpointing=True,\n",
+    "    dataloader_num_workers=8,  # added to fix a thrashing issue: the GPU was starved of data\n",
+    "    dataloader_pin_memory=True,  # pin host memory for faster GPU transfers\n",
+    "    torch_compile=True  # added to speed up training\n",
+    ")\n",
+    "\n",
+    "collator = DataCollator(tokenizer[\"PAD_None\"], copy_inputs_as_labels=True, pad_on_left=True)  # not sure about pad_on_left; it might get better results\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    args=training_config,\n",
+    "    data_collator=collator,\n",
+    "    train_dataset=dataset_train,\n",
+    "    eval_dataset=dataset_valid,\n",
+    "    compute_metrics=compute_metrics,\n",
+    "    callbacks=None,\n",
+    "    preprocess_logits_for_metrics=preprocess_logits,\n",
+    ")"
+   ]
+  },
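For reference, the effective batch size here is `per_device_train_batch_size * gradient_accumulation_steps` = 24 × 3 = 72 sequences per optimizer step on a single GPU, so the 40,000 `max_steps` consume roughly 2.88M chunk samples.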
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "del model\n",
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(model)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Training\n",
+    "train_result = trainer.train()\n",
+    "trainer.save_model()  # saves the model weights and config to model_dir\n",
+    "trainer.log_metrics(\"train\", train_result.metrics)\n",
+    "trainer.save_metrics(\"train\", train_result.metrics)\n",
+    "trainer.save_state()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create_model_card is a Trainer method, not a model method\n",
+    "trainer.create_model_card(tags=[\"mistral\", \"midi\", \"miditok\", \"music\", \"instrument\"],\n",
+    "                          model_name=\"Mistral_MidiTok_Transformer_Single_Instrument_Small\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.hub_model_id = \"adricl/midi_single_instrument_mistral_transformer\"\n",
+    "\n",
+    "model.push_to_hub(commit_message=\"Training Basic Model for Mistral MidiTok Transformer Single Instrument Small\", repo_id=\"adricl/midi_single_instrument_mistral_transformer\",\n",
+    "                  token=\"\")\n"
+   ]
+  },
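Rather than passing an inline `token=""` (left blank above), it is safer to authenticate once and let `push_to_hub` use the cached credential. A sketch assuming the token is exported as the `HF_TOKEN` environment variable:

```python
# Authenticate via the environment instead of hard-coding a token (sketch).
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])  # assumes HF_TOKEN is set in the shell
model.push_to_hub(
    repo_id="adricl/midi_single_instrument_mistral_transformer",
    commit_message="Training Basic Model for Mistral MidiTok Transformer Single Instrument Small",
)
```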
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# For TensorBoard, run: `tensorboard --logdir runs/`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = AutoConfig.from_pretrained(str(model_dir / \"config.json\"))\n",
+    "# from_pretrained expects the checkpoint directory, not the .safetensors file itself\n",
+    "model = AutoModelForCausalLM.from_pretrained(str(model_dir), config=config)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Generate music"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "cellView": "form",
+    "id": "OaNkGcFo9UP_"
+   },
+   "outputs": [],
+   "source": [
+    "# For single-track MIDI file splits\n",
+    "\n",
+    "gen_results_path = root_save / 'gen_res'\n",
+    "gen_results_path.mkdir(parents=True, exist_ok=True)\n",
+    "generation_config = GenerationConfig(\n",
+    "    max_new_tokens=200,  # extends samples by 200 tokens\n",
+    "    num_beams=1,         # no beam search...\n",
+    "    do_sample=True,      # ...sample instead\n",
+    "    temperature=0.9,\n",
+    "    top_k=15,\n",
+    "    top_p=0.95,\n",
+    "    epsilon_cutoff=3e-4,\n",
+    "    eta_cutoff=1e-3,\n",
+    "    pad_token_id=tokenizer.pad_token_id,\n",
+    ")\n",
+    "\n",
+    "# The sequences are padded on the left, so that the last token along the time dimension\n",
+    "# is always the last token of each sequence, allowing efficient batched generation\n",
+    "collator.pad_on_left = True\n",
+    "collator.eos_token = None\n",
+    "dataloader_test = DataLoader(dataset_preview, batch_size=24, collate_fn=collator)\n",
+    "model.eval()\n",
+    "count = 0\n",
+    "for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)\n",
+    "    print(batch)\n",
+    "    res = model.generate(\n",
+    "        inputs=batch[\"input_ids\"].to(model.device),\n",
+    "        attention_mask=batch[\"attention_mask\"].to(model.device),\n",
+    "        generation_config=generation_config)  # (N,T)\n",
+    "\n",
+    "    # Saves the generated music, as MIDI files and tokens (json)\n",
+    "    for prompt, continuation in zip(batch[\"input_ids\"], res):\n",
+    "        generated = continuation[len(prompt):]\n",
+    "        midi = tokenizer.decode([deepcopy(generated.tolist())])\n",
+    "        tokens = [generated, prompt, continuation]  # sequences of different lengths\n",
+    "        tokens = [seq.tolist() for seq in tokens]\n",
+    "        for tok_seq in tokens[1:]:\n",
+    "            _midi = tokenizer.decode([deepcopy(tok_seq)])\n",
+    "            midi.tracks.append(_midi.tracks[0])\n",
+    "\n",
+    "        file_name = file_name_lookup[count]\n",
+    "        print(file_name)\n",
+    "        midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens), original file {file_name}'\n",
+    "        midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'\n",
+    "        if len(midi.tracks) > 2:\n",
+    "            midi.tracks[2].name = 'Original sample and continuation'\n",
+    "        midi.dump_midi(gen_results_path / f'{count}_{file_name}.mid')\n",
+    "        tokenizer.save_tokens(tokens, gen_results_path / f'{count}_{file_name}.json')\n",
+    "\n",
+    "        count += 1"
+   ]
+  },
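The loop above only continues existing prompts. Generating from scratch works with the same `GenerationConfig` by feeding a BOS-only prompt; a sketch under the same tokenizer/model assumptions, not part of the committed notebook:

```python
# Unconditional generation from a BOS-only prompt (illustrative only).
bos = torch.full((1, 1), tokenizer["BOS_None"], dtype=torch.long, device=model.device)
out = model.generate(inputs=bos, generation_config=generation_config)  # (1, T)
score = tokenizer.decode([out[0].tolist()])
score.dump_midi(gen_results_path / "from_scratch.mid")
```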
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(file_name_lookup)"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [],
+   "machine_shape": "hm",
+   "name": "Optimus_VIRTUOSO_Multi_Instrumental_RGA_Edition.ipynb",
+   "private_outputs": true,
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.5"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
HuggingFace_Mistral_Transformer_Single_Instrument.ipynb
CHANGED
@@ -37,32 +37,181 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {
   "cellView": "form",
   "id": "fX12Yquyuihc"
  },
  "outputs": [],
  "source": [
-  "\n",
-  "\n",
   "from copy import deepcopy\n",
   "from pathlib import Path\n",
-  "from random import shuffle\n",
   "\n",
   "from evaluate import load as load_metric\n",
-  "from miditok import REMI, TokenizerConfig
   "from miditok.pytorch_data import DatasetMIDI, DataCollator\n",
   "from miditok.utils import split_files_for_training\n",
   "\n",
   "from miditok.data_augmentation import augment_dataset\n",
-  "from torch import Tensor, argmax\n",
   "from torch.utils.data import DataLoader\n",
   "from torch.cuda import is_available as cuda_available, is_bf16_supported\n",
   "from torch.backends.mps import is_available as mps_available\n",
   "from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig\n",
   "from transformers.trainer_utils import set_seed\n",
-  "from tqdm import tqdm"
  ]
 },
 {
@@ -75,13 +224,10 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Seed\n",
-  "set_seed(777)\n",
-  "\n",
   "# Our tokenizer's configuration\n",
   "BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}\n",
   "TOKENIZER_PARAMS = {\n",
@@ -114,14 +260,16 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
  "outputs": [],
  "source": [
-  "root_data_dir = Path('/
   "root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')\n",
   "\n",
-  "tokenizer_name = \"
  ]
 },
 {
@@ -132,13 +280,23 @@
  "source": [
   "\n",
   "# Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 30k tokens\n",
-  "data_dirs = [\"adl-piano-midi\", \"maestro-v3.0.0\", \"musicnet_midis\" ] # for single \n",
   "midi_paths = []\n",
   "for data_dir in data_dirs:\n",
   "    path = Path(root_data_dir / 'Traning Data' / data_dir)\n",
   "    midi_paths.extend(list(path.resolve().glob(\"**/*.mid\")) + list(path.resolve().glob(\"**/*.midi\")))\n",
   "\n",
-  "print(f\"Found {len(midi_paths)} MIDI files\")"
  ]
 },
 {
@@ -149,8 +307,8 @@
  "source": [
   "#Note the size of the dataset is quite large, so it requires a huge amount of memory to train the tokenizer for 61749 files it took 64gb of memory\n",
   "tokenizer.train(\n",
-  "    vocab_size=
-  "    files_paths=
   ")\n",
   "tokenizer.save(root_save / tokenizer_name)\n",
   "\n"
@@ -158,11 +316,11 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
  "outputs": [],
  "source": [
-  "tokenizer = REMI(params=Path(root_save / tokenizer_name))"
  ]
 },
 {
@@ -177,12 +335,19 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
  "outputs": [],
  "source": [
   "sequence_length = 1024 # The maximum sequence length for data samples.\n",
-  "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"]}"
  ]
 },
 {
@@ -191,30 +356,27 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "    try:\n",
-  "        scores = [Score(file_path)]\n",
-  "    except SCORE_LOADING_EXCEPTION:\n",
-  "        continue\n",
-  "\n",
-  "    for track in scores[0].tracks:\n",
-  "        values = track.notes['pitch']\n",
-  "        result = rms(values)\n",
-  "\n",
   "\n",
-  "
-  "
  ]
 },
 {
@@ -225,6 +387,7 @@
  "source": [
   "# Split MIDI paths in train/valid/test sets\n",
   "total_num_files = len(midi_paths)\n",
   "num_files_valid = round(total_num_files * 0.15)\n",
   "num_files_test = round(total_num_files * 0.15)\n",
   "shuffle(midi_paths)\n",
@@ -248,6 +411,7 @@
   "        save_dir=subset_chunks_dir,\n",
   "        max_seq_len=sequence_length,\n",
   "        num_overlap_bars=2,\n",
   "    )\n",
   "\n",
   "    if subset_name == 'train':\n",
@@ -261,6 +425,16 @@
   "    )\n"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -269,47 +443,39 @@
  "source": [
   "# Create Dataset and Collator for training\n",
   "midi_paths_train = list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.midi\"))\n",
-  "
-  "
   "\n",
-  "\n",
-  "\n",
-  "dataset_train = DatasetMIDI(
-  "
-  "dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)\n",
-  "print (len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))"
  ]
 },
 {
- "cell_type": "
  "metadata": {},
  "source": [
-  "
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
- "metadata": {
-  "tags": [
-   "Generate Preview Files"
-  ]
- },
  "outputs": [],
  "source": [
-  "
-  "
-  "
-  "
-  "\n"
-  "preview_dir = Path(root_save / \"preview\")\n",
-  "split_files_for_training(\n",
-  "    files_paths=preview_files_path,\n",
-  "    tokenizer=tokenizer,\n",
-  "    save_dir=preview_dir,\n",
-  "    max_seq_len=sequence_length,\n",
-  "    num_overlap_bars=2,\n",
-  "    )\n"
  ]
 },
 {
@@ -318,18 +484,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "
-  "midi_split_preview = list(valid_midi_path.resolve().glob(\"**/*.mid\")) + list(valid_midi_path.resolve().glob(\"**/*.midi\"))\n",
-  "\n",
-  "print(len(midi_split_preview))\n",
-  "file_name_lookup = []\n",
-  "def func_to_get_labels(p1, p2, p3):\n",
-  "    if p3.name not in file_name_lookup:\n",
-  "        file_name_lookup.append(p3.name)\n",
-  "    return file_name_lookup.index(p3.name)\n",
-  "    \n",
-  "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"func_to_get_labels\" : func_to_get_labels}\n",
-  "dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)"
  ]
 },
 {
@@ -341,12 +496,15 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
  "outputs": [],
  "source": [
-  "
-  "dataset_dir.
  ]
 },
 {
@@ -355,23 +513,27 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "import 
-  "
-  "
-  "
  ]
 },
 {
- "cell_type": "
- "execution_count": null,
  "metadata": {},
- "outputs": [],
  "source": [
-  "
-  "dataset_train = torch.load(Path(dataset_dir / \"dataset_train.pt\"))\n",
-  "dataset_valid = torch.load(Path(dataset_dir / \"dataset_valid.pt\"))\n",
-  "dataset_test = torch.load(Path(dataset_dir / \"dataset_test.pt\"))\n",
-  "\n"
  ]
 },
 {
@@ -380,7 +542,33 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "
  ]
 },
 {
@@ -398,9 +586,22 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
- "outputs": [
  "source": [
   "# Creates model\n",
   "model_config = MistralConfig(\n",
@@ -411,12 +612,80 @@
   "    num_attention_heads=8, # default 32\n",
   "    num_key_value_heads=4, # default 8\n",
   "    sliding_window=256, # default 4096\n",
-  "    max_position_embeddings=
   "    pad_token_id=tokenizer['PAD_None'],\n",
   "    bos_token_id=tokenizer['BOS_None'],\n",
   "    eos_token_id=tokenizer['EOS_None'],\n",
   ")\n",
-  "
  ]
 },
 {
@@ -429,9 +698,17 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
- "outputs": [
  "source": [
   "model_dir = root_save / 'run'\n",
   "model_dir_str = str(model_dir)\n",
@@ -442,7 +719,25 @@
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
- "outputs": [
  "source": [
   "metrics = {metric: load_metric(metric) for metric in [\"accuracy\"]}\n",
   "\n",
@@ -483,22 +778,23 @@
   "USE_MPS = not USE_CUDA and mps_available()\n",
   "training_config = TrainingArguments(\n",
   "    model_dir_str, False, True, True, False, \"steps\",\n",
-  "    per_device_train_batch_size=
-  "    per_device_eval_batch_size=
   "    gradient_accumulation_steps=3, #change this to 4\n",
   "    eval_accumulation_steps=None,\n",
-  "    eval_steps=
   "    learning_rate=1e-4,\n",
   "    weight_decay=0.01,\n",
-  "    max_grad_norm=
-  "    max_steps=
-  "    lr_scheduler_type=\"
-  "    warmup_ratio=0.
   "    log_level=\"debug\",\n",
   "    logging_strategy=\"steps\",\n",
-  "    logging_steps=
   "    save_strategy=\"steps\",\n",
-  "    save_steps=
   "    save_total_limit=5,\n",
   "    no_cuda=not USE_CUDA,\n",
   "    seed=444,\n",
@@ -509,11 +805,11 @@
   "    load_best_model_at_end=True,\n",
   "    label_smoothing_factor=0.,\n",
   "    optim=\"adamw_torch\",\n",
-  "    report_to=[\"tensorboard\"],\n",
-  "    gradient_checkpointing=
   "    dataloader_num_workers=8, #added to fix trashing isssue with the gpu not having enough data to process\n",
   "    dataloader_pin_memory=True, #we want the dataset in memory\n",
-  "    torch_compile=
   "    \n",
   ")\n",
   "\n",
@@ -538,7 +834,18 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "
  ]
 },
 {
@@ -546,6 +853,159 @@
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Training\n",
   "train_result = trainer.train()\n",
@@ -555,6 +1015,62 @@
   "trainer.save_state()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -567,9 +1083,32 @@
 },
 {
  "cell_type": "code",
- "execution_count":
  "metadata": {},
- "outputs": [
  "source": [
   "\n",
   "model.hub_model_id = \"adricl/midi_single_instrument_mistral_transformer\"\n",
@@ -699,7 +1238,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.
  },
  "vscode": {
   "interpreter": {

 },
 {
  "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Requirement already satisfied: pip in /usr/local/lib/python3.11/dist-packages (25.2)\n"
+   ]
+  },
+  {
+   "name": "stderr",
+   "output_type": "stream",
+   "text": [
+    "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n"
+   ]
+  }
+ ],
+ "source": [
+  "%pip install --upgrade pip\n"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Requirement already satisfied: evaluate in /usr/local/lib/python3.11/dist-packages (0.4.6)\n",
+    "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.56.2)\n",
+    "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
+    "Requirement already satisfied: miditok in /usr/local/lib/python3.11/dist-packages (3.0.6.post1)\n",
+    "Requirement already satisfied: accelerate in /usr/local/lib/python3.11/dist-packages (1.10.1)\n",
+    "Requirement already satisfied: tensorboardX in /usr/local/lib/python3.11/dist-packages (2.6.4)\n",
+    "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.7.2)\n",
+    "Requirement already satisfied: wandb in /usr/local/lib/python3.11/dist-packages (0.22.0)\n",
+    "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (4.1.1)\n",
+    "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.1.2)\n",
+    "Requirement already satisfied: dill in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.4.0)\n",
+    "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.3.2)\n",
+    "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.32.3)\n",
+    "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from evaluate) (3.5.0)\n",
+    "Requirement already satisfied: multiprocess in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.70.16)\n",
+    "Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2024.10.0)\n",
+    "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.35.1)\n",
+    "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from evaluate) (24.2)\n",
+    "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.16.1)\n",
+    "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
+    "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2025.9.18)\n",
+    "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.22.1)\n",
+    "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.6.2)\n",
+    "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.12.2)\n",
+    "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.7.0->evaluate) (1.1.10)\n",
+    "Requirement already satisfied: symusic>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from miditok) (0.5.8)\n",
+    "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate) (7.0.0)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate) (7.0.0)\n",
|
| 100 |
+
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate) (2.8.0.dev20250319+cu128)\n",
|
| 101 |
+
"Requirement already satisfied: protobuf>=3.20 in /usr/local/lib/python3.11/dist-packages (from tensorboardX) (6.32.1)\n",
|
| 102 |
+
"Requirement already satisfied: scipy>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.16.2)\n",
|
| 103 |
+
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.5.2)\n",
|
| 104 |
+
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.6.0)\n",
|
| 105 |
+
"Requirement already satisfied: click>=8.0.1 in /usr/local/lib/python3.11/dist-packages (from wandb) (8.3.0)\n",
|
| 106 |
+
"Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from wandb) (3.1.45)\n",
|
| 107 |
+
"Requirement already satisfied: platformdirs in /usr/local/lib/python3.11/dist-packages (from wandb) (4.3.7)\n",
|
| 108 |
+
"Requirement already satisfied: pydantic<3 in /usr/local/lib/python3.11/dist-packages (from wandb) (2.11.9)\n",
|
| 109 |
+
"Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from wandb) (2.39.0)\n",
|
| 110 |
+
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb) (0.7.0)\n",
|
| 111 |
+
"Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb) (2.33.2)\n",
|
| 112 |
+
"Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb) (0.4.1)\n",
|
| 113 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.4.1)\n",
|
| 114 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.10)\n",
|
| 115 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2.3.0)\n",
|
| 116 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2025.1.31)\n",
|
| 117 |
+
"Requirement already satisfied: pyarrow>=21.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (21.0.0)\n",
|
| 118 |
+
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (3.12.15)\n",
|
| 119 |
+
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (2.6.1)\n",
|
| 120 |
+
"Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.4.0)\n",
|
| 121 |
+
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (25.3.0)\n",
|
| 122 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.7.0)\n",
|
| 123 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (6.6.4)\n",
|
| 124 |
+
"Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (0.3.2)\n",
|
| 125 |
+
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.20.1)\n",
|
| 126 |
+
"Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.12)\n",
|
| 127 |
+
"Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.2)\n",
|
| 128 |
+
"Requirement already satisfied: pySmartDL in /usr/local/lib/python3.11/dist-packages (from symusic>=0.5.0->miditok) (1.3.4)\n",
|
| 129 |
+
"Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (1.13.3)\n",
|
| 130 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (3.4.2)\n",
|
| 131 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (3.1.4)\n",
|
| 132 |
+
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.61 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.61)\n",
|
| 133 |
+
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.57 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.57)\n",
|
| 134 |
+
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.57 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.57)\n",
|
| 135 |
+
"Requirement already satisfied: nvidia-cudnn-cu12==9.8.0.87 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (9.8.0.87)\n",
|
| 136 |
+
"Requirement already satisfied: nvidia-cublas-cu12==12.8.3.14 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.3.14)\n",
|
| 137 |
+
"Requirement already satisfied: nvidia-cufft-cu12==11.3.3.41 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (11.3.3.41)\n",
|
| 138 |
+
"Requirement already satisfied: nvidia-curand-cu12==10.3.9.55 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (10.3.9.55)\n",
|
| 139 |
+
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.2.55 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (11.7.2.55)\n",
|
| 140 |
+
"Requirement already satisfied: nvidia-cusparse-cu12==12.5.7.53 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.5.7.53)\n",
|
| 141 |
+
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (0.6.3)\n",
|
| 142 |
+
"Requirement already satisfied: nvidia-nccl-cu12==2.25.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (2.25.1)\n",
|
| 143 |
+
"Requirement already satisfied: nvidia-nvtx-cu12==12.8.55 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.55)\n",
|
| 144 |
+
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.61 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (12.8.61)\n",
|
| 145 |
+
"Requirement already satisfied: nvidia-cufile-cu12==1.13.0.11 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (1.13.0.11)\n",
|
| 146 |
+
"Requirement already satisfied: pytorch-triton==3.3.0+git96316ce5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate) (3.3.0+git96316ce5)\n",
|
| 147 |
+
"Requirement already satisfied: setuptools>=40.8.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-triton==3.3.0+git96316ce5->torch>=2.0.0->accelerate) (77.0.1)\n",
|
| 148 |
+
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0)\n",
|
| 149 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=2.0.0->accelerate) (2.1.5)\n",
|
| 150 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2.9.0.post0)\n",
|
| 151 |
+
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2025.2)\n",
|
| 152 |
+
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2025.2)\n",
|
| 153 |
+
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.16.0)\n",
|
| 154 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
|
| 155 |
+
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
|
| 156 |
+
]
|
| 157 |
+
}
|
| 158 |
+
],
|
| 159 |
+
"source": [
|
| 160 |
+
"%pip install evaluate transformers tqdm miditok accelerate tensorboardX scikit-learn wandb"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "code",
|
| 165 |
+
"execution_count": 3,
|
| 166 |
"metadata": {
|
| 167 |
"cellView": "form",
|
| 168 |
"id": "fX12Yquyuihc"
|
| 169 |
},
|
| 170 |
"outputs": [],
|
| 171 |
"source": [
|
|
|
|
|
|
|
172 |   "from copy import deepcopy\n",
173 |   "from pathlib import Path\n",
174 | + "from random import shuffle, sample\n",
175 |   "\n",
176 |   "from evaluate import load as load_metric\n",
177 | + "from miditok import REMI, TokenizerConfig\n",
178 |   "from miditok.pytorch_data import DatasetMIDI, DataCollator\n",
179 |   "from miditok.utils import split_files_for_training\n",
180 |   "\n",
181 |   "from miditok.data_augmentation import augment_dataset\n",
182 | + "from torch import Tensor, argmax, torch\n",
183 |   "from torch.utils.data import DataLoader\n",
184 |   "from torch.cuda import is_available as cuda_available, is_bf16_supported\n",
185 |   "from torch.backends.mps import is_available as mps_available\n",
186 |   "from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig\n",
187 |   "from transformers.trainer_utils import set_seed\n",
188 | + "from tqdm import tqdm\n",
189 | + "\n",
190 | + "# Seed\n",
191 | + "set_seed(777)"
192 | + ]
193 | + },
194 | + {
195 | + "cell_type": "code",
196 | + "execution_count": 4,
197 | + "metadata": {},
198 | + "outputs": [
199 | + {
200 | + "name": "stderr",
201 | + "output_type": "stream",
202 | + "text": [
203 | + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33madric-landman\u001b[0m (\u001b[33madric-landman-hobby\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
204 | + ]
205 | + }
206 | + ],
207 | + "source": [
208 | + "# Weights & Biases (wandb)\n",
209 | + "import wandb\n",
210 | + "import os\n",
211 | + "wandb.login()  # API key redacted; supply it via the WANDB_API_KEY environment variable, never in the notebook\n",
212 | + "\n",
213 | + "os.environ[\"WANDB_PROJECT\"]=\"midi_music_maker\"\n",
214 | + "\n"
215 |   ]
216 |   },
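The cell above logs in to Weights & Biases; the API key that was hard-coded in a comment in the original commit has been redacted here, since secrets should never be committed. A minimal sketch of a safer pattern, assuming the key is exported as `WANDB_API_KEY` in the shell before the kernel starts:

```python
import os
import wandb

# wandb.login() picks up WANDB_API_KEY from the environment (or ~/.netrc),
# so no secret ever has to appear in the notebook or the git history.
assert "WANDB_API_KEY" in os.environ, "export WANDB_API_KEY before running"
wandb.login()
os.environ["WANDB_PROJECT"] = "midi_music_maker"
```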
217 |   {
224 |   },
225 |   {
226 |   "cell_type": "code",
227 | + "execution_count": 6,
228 |   "metadata": {},
229 |   "outputs": [],
230 |   "source": [
231 |   "# Our tokenizer's configuration\n",
232 |   "BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}\n",
233 |   "TOKENIZER_PARAMS = {\n",
260 |   },
261 |   {
262 |   "cell_type": "code",
263 | + "execution_count": 7,
264 |   "metadata": {},
265 |   "outputs": [],
266 |   "source": [
267 | + "root_data_dir = Path('/workspace/traindata/data')\n",
268 |   "root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')\n",
269 |   "\n",
270 | + "tokenizer_name = \"HuggingFace_Mistral_Transformer_Single_Instrument_v4_single_track.json\"\n",
271 | + "dataset_dir = root_save / \"data\"\n",
272 | + "dataset_dir.mkdir(parents=True, exist_ok=True)"
273 |   ]
274 |   },
275 |   {
280 |   "source": [
281 |   "\n",
282 |   "# Train the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 24k tokens\n",
283 | + "#data_dirs = [\"adl-piano-midi\", \"maestro-v3.0.0\", \"musicnet_midis\"] # for single instrument\n",
284 | + "data_dirs = [\"MIDIs\"]\n",
285 |   "midi_paths = []\n",
286 |   "for data_dir in data_dirs:\n",
287 |   "    path = Path(root_data_dir / 'Traning Data' / data_dir)\n",
288 |   "    midi_paths.extend(list(path.resolve().glob(\"**/*.mid\")) + list(path.resolve().glob(\"**/*.midi\")))\n",
289 |   "\n",
290 | + "print(f\"Found {len(midi_paths)} MIDI files\")\n",
291 | + "\n",
292 | + "shuffle(midi_paths)\n",
293 | + "\n",
294 | + "# We need a subset of the files, otherwise training the tokenizer takes too long\n",
295 | + "percentage_to_select = 0.15\n",
296 | + "num_files_to_select = int(len(midi_paths) * percentage_to_select)\n",
297 | + "\n",
298 | + "subset_midi_paths = sample(midi_paths, num_files_to_select)\n",
299 | + "print(f\"Selected {len(subset_midi_paths)} MIDI files\")"
300 |   ]
301 |   },
302 |   {
307 |   "source": [
308 |   "# Note: the dataset is quite large, so training the tokenizer needs a huge amount of memory; for 61,749 files it took 64 GB\n",
309 |   "tokenizer.train(\n",
310 | + "    vocab_size=24000,\n",
311 | + "    files_paths=subset_midi_paths,\n",
312 |   ")\n",
313 |   "tokenizer.save(root_save / tokenizer_name)\n",
314 |   "\n"
316 |   },
317 |   {
318 |   "cell_type": "code",
319 | + "execution_count": 8,
320 |   "metadata": {},
321 |   "outputs": [],
322 |   "source": [
323 | + "tokenizer = REMI(params=Path(root_save / tokenizer_name))\n"
324 |   ]
325 |   },
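After reloading the trained tokenizer from its saved JSON params, a quick sanity check confirms the BPE training took effect. A minimal sketch; `len(tokenizer)` and the special-token lookups are miditok's public API, and the expected size of 24000 comes from the `vocab_size` passed to `tokenizer.train` above:

```python
# Vocabulary size should match the vocab_size used during BPE training.
print(len(tokenizer))  # expected: 24000

# Special token ids used throughout the rest of the notebook.
print(tokenizer["BOS_None"], tokenizer["EOS_None"], tokenizer["PAD_None"])
```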
326 |   {
335 |   },
336 |   {
337 |   "cell_type": "code",
338 | + "execution_count": 9,
339 |   "metadata": {},
340 |   "outputs": [],
341 |   "source": [
342 |   "sequence_length = 1024 # The maximum sequence length for data samples.\n",
343 | + "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"pre_tokenize\": True, \"pre_tokenize_thread_count\": 7}"
344 | + ]
345 | + },
346 | + {
347 | + "cell_type": "markdown",
348 | + "metadata": {},
349 | + "source": [
350 | + "# Test splitting files for training and testing purposes"
351 |   ]
352 |   },
353 |   {
356 |   "metadata": {},
357 |   "outputs": [],
358 |   "source": [
359 | + "from pathlib import Path\n",
360 | + "# The split will need to add the BPM to the files it splits\n",
361 | + "#\n",
362 | + "file_paths_test = [\n",
363 | + "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Right Here, Right Now.mid'),\n",
364 | + "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Fatboy Slim/Praise You.mid'),\n",
365 | + "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Iris.mid'),\n",
366 | + "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Goo Goo Dolls/Slide.mid'),\n",
367 | + "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/James Brown/Sex Machine (Get Up I Feel Like Being A).mid'),\n",
368 | + "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.1.mid'),\n",
369 | + "    Path('/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/clean_midi/Jamiroquai/Virtual Insanity.mid')\n",
370 | + "]\n",
371 |   "\n",
372 | + "split_files_for_training(\n",
373 | + "    files_paths=file_paths_test,\n",
374 | + "    tokenizer=tokenizer,\n",
375 | + "    save_dir=Path('/home/wombat/Documents/projects/music/midiTok/data/HuggingFace_Mistral_Transformer_Single_Instrument/test'),\n",
376 | + "    max_seq_len=sequence_length,\n",
377 | + "    num_overlap_bars=2,\n",
378 | + "    skip_drums=True\n",
379 | + ")"
380 |   ]
381 |   },
382 |   {
387 |   "source": [
388 |   "# Split MIDI paths in train/valid/test sets\n",
389 |   "total_num_files = len(midi_paths)\n",
390 | + "\n",
391 |   "num_files_valid = round(total_num_files * 0.15)\n",
392 |   "num_files_test = round(total_num_files * 0.15)\n",
393 |   "shuffle(midi_paths)\n",
411 |   "        save_dir=subset_chunks_dir,\n",
412 |   "        max_seq_len=sequence_length,\n",
413 |   "        num_overlap_bars=2,\n",
414 | + "        skip_drums=True\n",
415 |   "    )\n",
416 |   "\n",
417 |   "    if subset_name == 'train':\n",
425 |   "        )\n"
426 |   ]
427 |   },
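The two `round()` calls above imply a 70/15/15 train/valid/test split. A worked example of the arithmetic with a hypothetical total of 1,000 files:

```python
total_num_files = 1000  # hypothetical count for illustration
num_files_valid = round(total_num_files * 0.15)  # 150
num_files_test = round(total_num_files * 0.15)   # 150
num_files_train = total_num_files - num_files_valid - num_files_test  # 700
print(num_files_train, num_files_valid, num_files_test)
```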
428 | + {
429 | + "cell_type": "code",
430 | + "execution_count": null,
431 | + "metadata": {},
432 | + "outputs": [],
433 | + "source": [
434 | + "# Since the datasets are still too large after splitting, we only train against 25% of the split data\n",
435 | + "sample_subset_per = .25"
436 | + ]
437 | + },
438 |   {
439 |   "cell_type": "code",
440 |   "execution_count": null,
443 |   "source": [
444 |   "# Create Dataset and Collator for training\n",
445 |   "midi_paths_train = list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.midi\"))\n",
446 | + "sample_count = round(len(midi_paths_train)*sample_subset_per)\n",
447 | + "print(f\"sample count length: {sample_count} total count: {len(midi_paths_train)}\")\n",
448 |   "\n",
449 | + "midi_paths_train_sample = midi_paths_train[0:sample_count]\n",
450 | + "print(len(midi_paths_train_sample))\n",
451 | + "dataset_train = DatasetMIDI(midi_paths_train_sample, **kwargs_dataset)\n",
452 | + "torch.save(dataset_train, Path(dataset_dir / \"dataset_train.pt\"))"
453 |   ]
454 |   },
455 |   {
456 | + "cell_type": "code",
457 | + "execution_count": null,
458 |   "metadata": {},
459 | + "outputs": [],
460 |   "source": [
461 | + "midi_paths_valid = list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.midi\"))\n",
462 | + "midi_paths_valid = midi_paths_valid[:round(len(midi_paths_valid)*sample_subset_per)]  # round(): slice indices must be integers\n",
463 | + "print(len(midi_paths_valid))\n",
464 | + "dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)\n",
465 | + "torch.save(dataset_valid, Path(dataset_dir / \"dataset_valid.pt\"))"
466 |   ]
467 |   },
468 |   {
469 |   "cell_type": "code",
470 |   "execution_count": null,
471 | + "metadata": {},
472 |   "outputs": [],
473 |   "source": [
474 | + "midi_paths_test = list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.midi\"))\n",
475 | + "midi_paths_test = midi_paths_test[:round(len(midi_paths_test)*sample_subset_per)]  # round(): slice indices must be integers\n",
476 | + "print(len(midi_paths_test))\n",
477 | + "dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)\n",
478 | + "torch.save(dataset_test, Path(dataset_dir / \"dataset_test.pt\"))\n"
479 |   ]
480 |   },
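The Trainer construction itself is collapsed in this diff, so here is a minimal sketch of how these `DatasetMIDI` objects are typically batched with miditok's `DataCollator`; the `copy_inputs_as_labels` flag (which mirrors the inputs into the labels for causal-LM training) is an assumption based on miditok's documented API, not something visible in this diff:

```python
from torch.utils.data import DataLoader
from miditok.pytorch_data import DataCollator

# Pad batches with the tokenizer's PAD token; labels are a copy of the inputs.
collator = DataCollator(tokenizer["PAD_None"], copy_inputs_as_labels=True)
loader = DataLoader(dataset_train, batch_size=32, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape)  # (batch_size, seq_len) tensor of token ids
```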
481 |   {
484 |   "metadata": {},
485 |   "outputs": [],
486 |   "source": [
487 | + "print(len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))\n"
488 |   ]
489 |   },
490 |   {
496 |   },
497 |   {
498 |   "cell_type": "code",
499 | + "execution_count": 10,
500 |   "metadata": {},
501 |   "outputs": [],
502 |   "source": [
503 | + "\n",
504 | + "dataset_train = torch.load(Path(dataset_dir / \"dataset_train.pt\"), weights_only=False)\n",
505 | + "dataset_valid = torch.load(Path(dataset_dir / \"dataset_valid.pt\"), weights_only=False)\n",
506 | + "dataset_test = torch.load(Path(dataset_dir / \"dataset_test.pt\"), weights_only=False)\n",
507 | + "\n"
508 |   ]
509 |   },
510 |   {
513 |   "metadata": {},
514 |   "outputs": [],
515 |   "source": [
516 | + "import pickle\n",
517 | + "\n",
518 | + "test_file = open(Path(dataset_dir / \"dataset_test.pickle\"), 'wb')  # 'wb' so reruns overwrite rather than append stale pickles\n",
519 | + "pickle.dump(dataset_test, test_file)\n",
520 | + "test_file.close()\n",
521 | + "\n",
522 | + "print(dataset_test[0])\n",
523 | + "\n",
524 | + "test_file = open(Path(dataset_dir / \"dataset_test.pickle\"), 'rb')\n",
525 | + "test_pickle = pickle.load(test_file)\n",
526 | + "print(test_pickle)\n",
527 | + "print(test_pickle[0])\n",
528 | + "\n",
529 | + "\n"
530 |   ]
531 |   },
532 |   {
533 | + "cell_type": "markdown",
534 |   "metadata": {},
535 |   "source": [
536 | + "# Preview files data load and split"
537 |   ]
538 |   },
539 |   {
542 |   "metadata": {},
543 |   "outputs": [],
544 |   "source": [
545 | + "\n",
546 | + "#testing_files =  # NOTE: must be assigned a list of MIDI file paths before this cell can run\n",
547 | + "preview_files_path = []\n",
548 | + "for testing_file in testing_files:\n",
549 | + "    preview_files_path.append(Path(testing_file))\n",
550 | + "\n",
551 | + "preview_dir = Path(root_save / \"preview\")\n",
552 | + "split_files_for_training(\n",
553 | + "    files_paths=preview_files_path,\n",
554 | + "    tokenizer=tokenizer,\n",
555 | + "    save_dir=preview_dir,\n",
556 | + "    max_seq_len=sequence_length,\n",
557 | + "    num_overlap_bars=2,\n",
558 | + ")\n",
559 | + "\n",
560 | + "valid_midi_path = root_save / \"Maestro_valid\"\n",
561 | + "midi_split_preview = list(valid_midi_path.resolve().glob(\"**/*.mid\")) + list(valid_midi_path.resolve().glob(\"**/*.midi\"))\n",
562 | + "\n",
563 | + "print(len(midi_split_preview))\n",
564 | + "file_name_lookup = []\n",
565 | + "def func_to_get_labels(p1, p2, p3):\n",
566 | + "    if p3.name not in file_name_lookup:\n",
567 | + "        file_name_lookup.append(p3.name)\n",
568 | + "    return file_name_lookup.index(p3.name)\n",
569 | + "\n",
570 | + "kwargs_dataset = {\"max_seq_len\": sequence_length, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"func_to_get_labels\": func_to_get_labels}\n",
571 | + "dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)"
572 |   ]
573 |   },
574 |   {
586 |   },
587 |   {
588 |   "cell_type": "code",
589 | + "execution_count": 19,
590 |   "metadata": {},
591 | + "outputs": [
592 | + {
593 | + "name": "stderr",
594 | + "output_type": "stream",
595 | + "text": [
596 | + "Generate config GenerationConfig {\n",
597 | + " \"bos_token_id\": 1,\n",
598 | + " \"eos_token_id\": 2,\n",
599 | + " \"pad_token_id\": 0\n",
600 | + "}\n",
601 | + "\n"
602 | + ]
603 | + }
604 | + ],
605 |   "source": [
606 |   "# Creates model\n",
607 |   "model_config = MistralConfig(\n",
612 |   "    num_attention_heads=8, # default 32\n",
613 |   "    num_key_value_heads=4, # default 8\n",
614 |   "    sliding_window=256, # default 4096\n",
615 | + "    max_position_embeddings=8192, # no effect on the param count or training, just limits the input length # default 4096*32\n",
616 |   "    pad_token_id=tokenizer['PAD_None'],\n",
617 |   "    bos_token_id=tokenizer['BOS_None'],\n",
618 |   "    eos_token_id=tokenizer['EOS_None'],\n",
619 |   ")\n",
620 | + "\n",
621 | + "model = AutoModelForCausalLM.from_config(model_config)\n"
622 | + ]
623 | + },
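The training log further down reports 56,041,984 trainable parameters for this configuration; that figure can be verified as soon as the model is built. A one-line sketch using the standard `transformers` helper:

```python
# Count the freshly initialised model's trainable parameters.
print(f"{model.num_parameters(only_trainable=True):,} trainable parameters")
```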
624 | + {
625 | + "cell_type": "code",
626 | + "execution_count": 22,
627 | + "metadata": {},
628 | + "outputs": [
629 | + {
630 | + "name": "stderr",
631 | + "output_type": "stream",
632 | + "text": [
633 | + "loading configuration file /workspace/traindata/train/checkpoint-22000/config.json\n"
634 | + ]
635 | + },
636 | + {
637 | + "name": "stderr",
638 | + "output_type": "stream",
639 | + "text": [
640 | + "Model config MistralConfig {\n",
641 | + " \"architectures\": [\n",
642 | + " \"MistralForCausalLM\"\n",
643 | + " ],\n",
644 | + " \"attention_dropout\": 0.0,\n",
645 | + " \"bos_token_id\": 1,\n",
646 | + " \"dtype\": \"float32\",\n",
647 | + " \"eos_token_id\": 2,\n",
648 | + " \"head_dim\": null,\n",
649 | + " \"hidden_act\": \"silu\",\n",
650 | + " \"hidden_size\": 512,\n",
651 | + " \"initializer_range\": 0.02,\n",
652 | + " \"intermediate_size\": 2048,\n",
653 | + " \"max_position_embeddings\": 8192,\n",
654 | + " \"model_type\": \"mistral\",\n",
655 | + " \"num_attention_heads\": 8,\n",
656 | + " \"num_hidden_layers\": 8,\n",
657 | + " \"num_key_value_heads\": 4,\n",
658 | + " \"pad_token_id\": 0,\n",
659 | + " \"rms_norm_eps\": 1e-06,\n",
660 | + " \"rope_theta\": 10000.0,\n",
661 | + " \"sliding_window\": 256,\n",
662 | + " \"tie_word_embeddings\": false,\n",
663 | + " \"transformers_version\": \"4.56.2\",\n",
664 | + " \"use_cache\": true,\n",
665 | + " \"vocab_size\": 24000\n",
666 | + "}\n",
667 | + "\n",
668 | + "loading weights file /workspace/traindata/train/checkpoint-22000/model.safetensors\n",
669 | + "Generate config GenerationConfig {\n",
670 | + " \"bos_token_id\": 1,\n",
671 | + " \"eos_token_id\": 2,\n",
672 | + " \"pad_token_id\": 0\n",
673 | + "}\n",
674 | + "\n",
675 | + "All model checkpoint weights were used when initializing MistralForCausalLM.\n",
676 | + "\n",
677 | + "All the weights of MistralForCausalLM were initialized from the model checkpoint at /workspace/traindata/train/checkpoint-22000/model.safetensors.\n",
678 | + "If your task is similar to the task the model of the checkpoint was trained on, you can already use MistralForCausalLM for predictions without further training.\n",
679 | + "Generation config file not found, using a generation config created from the model config.\n"
680 | + ]
681 | + }
682 | + ],
683 | + "source": [
684 | + "# Only for continuing training of an existing model, not for creating a new one\n",
685 | + "model_dir = Path(\"/workspace/traindata/train/checkpoint-22000\")\n",
686 | + "\n",
687 | + "config = AutoConfig.from_pretrained(str(model_dir / \"config.json\"))\n",
688 | + "model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=str(model_dir / \"model.safetensors\"), from_tf=False, config=config)"
689 |   ]
690 |   },
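Passing the `model.safetensors` path together with a separately loaded config works, as the output above shows, but `from_pretrained` can also take the checkpoint directory itself and discover `config.json` and the weights on its own. A sketch of the more common idiom:

```python
# Equivalent, shorter load: config and weights are found inside the directory.
model = AutoModelForCausalLM.from_pretrained(str(model_dir))
```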
691 |   {
698 |   },
699 |   {
700 |   "cell_type": "code",
701 | + "execution_count": 12,
702 |   "metadata": {},
703 | + "outputs": [
704 | + {
705 | + "name": "stdout",
706 | + "output_type": "stream",
707 | + "text": [
708 | + "/workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run\n"
709 | + ]
710 | + }
711 | + ],
712 |   "source": [
713 |   "model_dir = root_save / 'run'\n",
714 |   "model_dir_str = str(model_dir)\n",
719 |   "cell_type": "code",
720 |   "execution_count": null,
721 |   "metadata": {},
722 | + "outputs": [
723 | + {
724 | + "name": "stderr",
725 | + "output_type": "stream",
726 | + "text": [
727 | + "PyTorch: setting up devices\n",
728 | + "average_tokens_across_devices is True but world size is 1. Setting it to False automatically.\n",
729 | + "max_steps is given, it will override any value given in num_train_epochs\n",
730 | + "Using auto half precision backend\n"
731 | + ]
732 | + },
733 | + {
734 | + "name": "stdout",
735 | + "output_type": "stream",
736 | + "text": [
737 | + "True\n"
738 | + ]
739 | + }
740 | + ],
741 |   "source": [
742 |   "metrics = {metric: load_metric(metric) for metric in [\"accuracy\"]}\n",
743 |   "\n",
778 |   "USE_MPS = not USE_CUDA and mps_available()\n",
779 |   "training_config = TrainingArguments(\n",
780 |   "    model_dir_str, False, True, True, False, \"steps\",\n",
781 | + "    per_device_train_batch_size=32, # GPU ~76% utilised at batch size 24 and at 32; try batch size 64 next time\n",
782 | + "    per_device_eval_batch_size=32, # was 24, now 32\n",
783 |   "    gradient_accumulation_steps=3, # change this to 4\n",
784 |   "    eval_accumulation_steps=None,\n",
785 | + "    eval_steps=3000,\n",
786 | + "    eval_delay=6000,\n",
787 |   "    learning_rate=1e-4,\n",
788 |   "    weight_decay=0.01,\n",
789 | + "    max_grad_norm=1.0,\n",
790 | + "    max_steps=30000,\n",
791 | + "    lr_scheduler_type=\"cosine\",\n",
792 | + "    warmup_ratio=0.08,\n",
793 |   "    log_level=\"debug\",\n",
794 |   "    logging_strategy=\"steps\",\n",
795 | + "    logging_steps=100,\n",
796 |   "    save_strategy=\"steps\",\n",
797 | + "    save_steps=3000,\n",
798 |   "    save_total_limit=5,\n",
799 |   "    no_cuda=not USE_CUDA,\n",
800 |   "    seed=444,\n",
805 |   "    load_best_model_at_end=True,\n",
806 |   "    label_smoothing_factor=0.,\n",
807 |   "    optim=\"adamw_torch\",\n",
808 | + "    report_to=[\"tensorboard\", \"wandb\"],\n",
809 | + "    gradient_checkpointing=False,\n",
810 |   "    dataloader_num_workers=8, # added to fix a thrashing issue where the GPU did not have enough data to process\n",
811 |   "    dataloader_pin_memory=True, # we want the dataset in memory\n",
812 | + "    torch_compile=False # try enabling torch.compile to speed up training\n",
813 |   "\n",
814 |   ")\n",
815 |   "\n",
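These arguments imply an effective batch size of 96 sequences per optimizer step, which matches the "Total train batch size (w. parallel, distributed & accumulation) = 96" line in the training log further below. The arithmetic:

```python
per_device_train_batch_size = 32
gradient_accumulation_steps = 3
num_devices = 1  # single GPU, per the log

effective_batch_size = (per_device_train_batch_size
                        * gradient_accumulation_steps
                        * num_devices)
print(effective_batch_size)  # 96 sequences contribute to each optimizer step
```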
834 |   "metadata": {},
835 |   "outputs": [],
836 |   "source": [
837 | + "torch.cuda.empty_cache()"
838 | + ]
839 | + },
840 | + {
841 | + "cell_type": "code",
842 | + "execution_count": null,
843 | + "metadata": {},
844 | + "outputs": [],
845 | + "source": [
846 | + "print(model)\n",
847 | + "os.environ['CUDA_LAUNCH_BLOCKING']=\"1\"\n",
848 | + "os.environ['TORCH_USE_CUDA_DSA'] = \"1\""
849 |   ]
850 |   },
851 |   {
853 |   "execution_count": null,
854 |   "metadata": {},
855 |   "outputs": [],
856 | + "source": [
857 | + "%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
858 | + ]
859 | + },
860 | + {
861 | + "cell_type": "code",
862 | + "execution_count": null,
863 | + "metadata": {},
864 | + "outputs": [],
865 | + "source": [
866 | + "seed=1\n",
867 | + "# # print(max(dataset_train[\"input_ids\"].max().item(), 0))\n",
868 | + "\n",
869 | + "torch.manual_seed(seed)\n",
870 | + "\n"
871 | + ]
872 | + },
873 | + {
874 | + "cell_type": "code",
875 | + "execution_count": 24,
876 | + "metadata": {},
877 | + "outputs": [
878 | + {
879 | + "name": "stderr",
880 | + "output_type": "stream",
881 | + "text": [
882 | + "Currently training with a batch size of: 32\n",
883 | + "***** Running training *****\n",
884 | + " Num examples = 5,570,752\n",
885 | + " Num Epochs = 1\n",
886 | + " Instantaneous batch size per device = 32\n",
887 | + " Total train batch size (w. parallel, distributed & accumulation) = 96\n",
888 | + " Gradient Accumulation steps = 3\n",
889 | + " Total optimization steps = 30,000\n",
890 | + " Number of trainable parameters = 56,041,984\n",
891 | + "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
892 | + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n"
893 | + ]
894 | + },
895 | + {
896 | + "data": {
897 | + "text/html": [
898 | + "\n",
899 | + " <div>\n",
900 | + " \n",
901 | + " <progress value='15166' max='30000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
902 | + " [15166/30000 3:11:26 < 3:07:16, 1.32 it/s, Epoch 0.26/1]\n",
903 | + " </div>\n",
904 | + " <table border=\"1\" class=\"dataframe\">\n",
905 | + " <thead>\n",
906 | + " <tr style=\"text-align: left;\">\n",
907 | + " <th>Step</th>\n",
908 | + " <th>Training Loss</th>\n",
909 | + " <th>Validation Loss</th>\n",
910 | + " <th>Accuracy</th>\n",
911 | + " </tr>\n",
912 | + " </thead>\n",
913 | + " <tbody>\n",
914 | + " <tr>\n",
915 | + " <td>6000</td>\n",
916 | + " <td>1.602200</td>\n",
917 | + " <td>1.751858</td>\n",
918 | + " <td>0.010508</td>\n",
919 | + " </tr>\n",
920 | + " <tr>\n",
921 | + " <td>9000</td>\n",
922 | + " <td>1.584300</td>\n",
923 | + " <td>1.732275</td>\n",
924 | + " <td>0.010426</td>\n",
925 | + " </tr>\n",
926 | + " <tr>\n",
927 | + " <td>12000</td>\n",
928 | + " <td>1.547300</td>\n",
929 | + " <td>1.712772</td>\n",
930 | + " <td>0.010505</td>\n",
931 | + " </tr>\n",
932 | + " <tr>\n",
933 | + " <td>15000</td>\n",
934 | + " <td>1.540700</td>\n",
935 | + " <td>1.694235</td>\n",
936 | + " <td>0.010407</td>\n",
937 | + " </tr>\n",
938 | + " </tbody>\n",
939 | + "</table><p>"
940 | + ],
941 | + "text/plain": [
942 | + "<IPython.core.display.HTML object>"
943 | + ]
944 | + },
945 | + "metadata": {},
946 | + "output_type": "display_data"
947 | + },
948 | + {
949 | + "name": "stderr",
950 | + "output_type": "stream",
951 | + "text": [
952 | + "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000\n",
953 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000/config.json\n",
954 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000/generation_config.json\n",
955 | + "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-3000/model.safetensors\n",
956 | + "\n",
957 | + "***** Running Evaluation *****\n",
958 | + " Num examples = 849907\n",
959 | + " Batch size = 32\n",
960 | + "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000\n",
961 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000/config.json\n",
962 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000/generation_config.json\n",
963 | + "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-6000/model.safetensors\n",
964 | + "\n",
965 | + "***** Running Evaluation *****\n",
966 | + " Num examples = 849907\n",
967 | + " Batch size = 32\n",
968 | + "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000\n",
969 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000/config.json\n",
970 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000/generation_config.json\n",
971 | + "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-9000/model.safetensors\n",
972 | + "\n",
973 | + "***** Running Evaluation *****\n",
974 | + " Num examples = 849907\n",
975 | + " Batch size = 32\n",
976 | + "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000\n",
977 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000/config.json\n",
978 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000/generation_config.json\n",
979 | + "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-12000/model.safetensors\n",
980 | + "\n",
981 | + "***** Running Evaluation *****\n",
982 | + " Num examples = 849907\n",
983 | + " Batch size = 32\n",
984 | + "Saving model checkpoint to /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000\n",
985 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000/config.json\n",
986 | + "Configuration saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000/generation_config.json\n",
987 | + "Model weights saved in /workspace/traindata/data/HuggingFace_Mistral_Transformer_Single_Instrument/run/checkpoint-15000/model.safetensors\n"
988 | + ]
989 | + },
990 | + {
991 | + "ename": "KeyboardInterrupt",
992 | + "evalue": "",
993 | + "output_type": "error",
994 | + "traceback": [
995 | + "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
996 | + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
997 | + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Training\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m train_result = \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m trainer.save_model() \u001b[38;5;66;03m# Saves the tokenizer too\u001b[39;00m\n\u001b[32m 4\u001b[39m trainer.log_metrics(\u001b[33m\"\u001b[39m\u001b[33mtrain\u001b[39m\u001b[33m\"\u001b[39m, train_result.metrics)\n",
998 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/trainer.py:2328\u001b[39m, in \u001b[36mTrainer.train\u001b[39m\u001b[34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[39m\n\u001b[32m 2326\u001b[39m hf_hub_utils.enable_progress_bars()\n\u001b[32m 2327\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2328\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2329\u001b[39m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m=\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2330\u001b[39m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m=\u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2331\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2332\u001b[39m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m=\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2333\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
999 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/trainer.py:2672\u001b[39m, in \u001b[36mTrainer._inner_training_loop\u001b[39m\u001b[34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[39m\n\u001b[32m 2665\u001b[39m context = (\n\u001b[32m 2666\u001b[39m functools.partial(\u001b[38;5;28mself\u001b[39m.accelerator.no_sync, model=model)\n\u001b[32m 2667\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i != \u001b[38;5;28mlen\u001b[39m(batch_samples) - \u001b[32m1\u001b[39m\n\u001b[32m 2668\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.accelerator.distributed_type != DistributedType.DEEPSPEED\n\u001b[32m 2669\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m contextlib.nullcontext\n\u001b[32m 2670\u001b[39m )\n\u001b[32m 2671\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[32m-> \u001b[39m\u001b[32m2672\u001b[39m tr_loss_step = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2674\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 2675\u001b[39m args.logging_nan_inf_filter\n\u001b[32m 2676\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[32m 2677\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m (torch.isnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch.isinf(tr_loss_step))\n\u001b[32m 2678\u001b[39m ):\n\u001b[32m 2679\u001b[39m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[32m 2680\u001b[39m tr_loss = tr_loss + tr_loss / (\u001b[32m1\u001b[39m + \u001b[38;5;28mself\u001b[39m.state.global_step - \u001b[38;5;28mself\u001b[39m._globalstep_last_logged)\n",
1000 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/trainer.py:4060\u001b[39m, in \u001b[36mTrainer.training_step\u001b[39m\u001b[34m(***failed resolving arguments***)\u001b[39m\n\u001b[32m 4057\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.accelerator.distributed_type == DistributedType.DEEPSPEED:\n\u001b[32m 4058\u001b[39m kwargs[\u001b[33m\"\u001b[39m\u001b[33mscale_wrt_gas\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m4060\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43maccelerator\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 4062\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m loss.detach()\n",
1001 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/accelerate/accelerator.py:2734\u001b[39m, in \u001b[36mAccelerator.backward\u001b[39m\u001b[34m(self, loss, **kwargs)\u001b[39m\n\u001b[32m 2732\u001b[39m \u001b[38;5;28mself\u001b[39m.lomo_backward(loss, learning_rate)\n\u001b[32m 2733\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2734\u001b[39m \u001b[43mloss\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
1002 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/torch/_tensor.py:648\u001b[39m, in \u001b[36mTensor.backward\u001b[39m\u001b[34m(self, gradient, retain_graph, create_graph, inputs)\u001b[39m\n\u001b[32m 638\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 639\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[32m 640\u001b[39m Tensor.backward,\n\u001b[32m 641\u001b[39m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[32m (...)\u001b[39m\u001b[32m 646\u001b[39m inputs=inputs,\n\u001b[32m 647\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m648\u001b[39m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mautograd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 649\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs\u001b[49m\n\u001b[32m 650\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1003 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py:353\u001b[39m, in \u001b[36mbackward\u001b[39m\u001b[34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[39m\n\u001b[32m 348\u001b[39m retain_graph = create_graph\n\u001b[32m 350\u001b[39m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[32m 351\u001b[39m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[32m 352\u001b[39m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m353\u001b[39m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 354\u001b[39m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 355\u001b[39m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 356\u001b[39m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 357\u001b[39m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 358\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 359\u001b[39m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 360\u001b[39m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 361\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1004 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py:824\u001b[39m, in \u001b[36m_engine_run_backward\u001b[39m\u001b[34m(t_outputs, *args, **kwargs)\u001b[39m\n\u001b[32m 822\u001b[39m unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[32m 823\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m824\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_execution_engine\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[32m 825\u001b[39m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m 826\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[32m 827\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 828\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
1005 | + "\u001b[31mKeyboardInterrupt\u001b[39m: "
1006 | + ]
1007 | + }
1008 | + ],
1009 |   "source": [
1010 |   "# Training\n",
1011 |   "train_result = trainer.train()\n",
1015 |   "trainer.save_state()"
1016 |   ]
1017 |   },
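The run above was interrupted around step 15,100, but checkpoints up to step 15,000 were saved under the output directory. Training can be picked up from the most recent one; `resume_from_checkpoint` is a standard `Trainer.train` argument:

```python
# Resume from the newest checkpoint found in training_config.output_dir.
train_result = trainer.train(resume_from_checkpoint=True)
```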
+
{
|
| 1019 |
+
"cell_type": "code",
|
| 1020 |
+
"execution_count": 25,
|
| 1021 |
+
"metadata": {},
|
| 1022 |
+
"outputs": [
|
| 1023 |
+
{
|
| 1024 |
+
"data": {
|
| 1025 |
+
"text/html": [],
|
| 1026 |
+
"text/plain": [
|
| 1027 |
+
"<IPython.core.display.HTML object>"
|
| 1028 |
+
]
|
| 1029 |
+
},
|
| 1030 |
+
"metadata": {},
|
| 1031 |
+
"output_type": "display_data"
|
| 1032 |
+
},
|
| 1033 |
+
{
|
| 1034 |
+
"data": {
|
| 1035 |
+
"text/html": [
|
| 1036 |
+
"<br> <style><br> .wandb-row {<br> display: flex;<br> flex-direction: row;<br> flex-wrap: wrap;<br> justify-content: flex-start;<br> width: 100%;<br> }<br> .wandb-col {<br> display: flex;<br> flex-direction: column;<br> flex-basis: 100%;<br> flex: 1;<br> padding: 10px;<br> }<br> </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>█▂█▁</td></tr><tr><td>eval/loss</td><td>█▆▃▁</td></tr><tr><td>eval/runtime</td><td>█▁▃▆</td></tr><tr><td>eval/samples_per_second</td><td>▁█▆▃</td></tr><tr><td>eval/steps_per_second</td><td>▁█▆▃</td></tr><tr><td>train/epoch</td><td>▁▁▁▁▁▂▂▂▂▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇███</td></tr><tr><td>train/global_step</td><td>▁▁▁▁▂▂▂▂▂▂▁▁▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇████</td></tr><tr><td>train/grad_norm</td><td>▅▃▄▅▆█▇▆▁▁▂▂▂▂▁▂▁▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train/learning_rate</td><td>▃▅▅▆▆█▁▂▄▄▆▇█████████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅</td></tr><tr><td>train/loss</td><td>█▇▆▆▆▅▄▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>0.01041</td></tr><tr><td>eval/loss</td><td>1.69424</td></tr><tr><td>eval/runtime</td><td>1748.5708</td></tr><tr><td>eval/samples_per_second</td><td>486.058</td></tr><tr><td>eval/steps_per_second</td><td>15.19</td></tr><tr><td>train/epoch</td><td>0.26022</td></tr><tr><td>train/global_step</td><td>15100</td></tr><tr><td>train/grad_norm</td><td>0.56171</td></tr><tr><td>train/learning_rate</td><td>6e-05</td></tr><tr><td>train/loss</td><td>1.5401</td></tr></table><br/></div></div>"
|
| 1037 |
+
],
|
| 1038 |
+
"text/plain": [
|
| 1039 |
+
"<IPython.core.display.HTML object>"
|
| 1040 |
+
]
|
| 1041 |
+
},
|
| 1042 |
+
"metadata": {},
|
| 1043 |
+
"output_type": "display_data"
|
| 1044 |
+
},
|
| 1045 |
+
{
|
| 1046 |
+
"data": {
|
| 1047 |
+
"text/html": [
|
| 1048 |
+
" View run <strong style=\"color:#cdcd00\">fast-yogurt-6</strong> at: <a href='https://wandb.ai/adric-landman-hobby/midi_music_maker/runs/g1fn393k' target=\"_blank\">https://wandb.ai/adric-landman-hobby/midi_music_maker/runs/g1fn393k</a><br> View project at: <a href='https://wandb.ai/adric-landman-hobby/midi_music_maker' target=\"_blank\">https://wandb.ai/adric-landman-hobby/midi_music_maker</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
| 1049 |
+
],
|
| 1050 |
+
"text/plain": [
|
| 1051 |
+
"<IPython.core.display.HTML object>"
|
| 1052 |
+
]
|
| 1053 |
+
},
|
| 1054 |
+
"metadata": {},
|
| 1055 |
+
"output_type": "display_data"
|
| 1056 |
+
},
|
| 1057 |
+
{
|
| 1058 |
+
"data": {
|
| 1059 |
+
"text/html": [
|
| 1060 |
+
"Find logs at: <code></code>"
|
| 1061 |
+
],
|
| 1062 |
+
"text/plain": [
|
| 1063 |
+
"<IPython.core.display.HTML object>"
|
| 1064 |
+
]
|
| 1065 |
+
},
|
| 1066 |
+
"metadata": {},
|
| 1067 |
+
"output_type": "display_data"
|
| 1068 |
+
}
|
| 1069 |
+
],
|
| 1070 |
+
"source": [
|
| 1071 |
+
"wandb.finish()"
|
| 1072 |
+
]
|
| 1073 |
+
},
|
| 1074 |
{
|
| 1075 |
"cell_type": "code",
|
| 1076 |
"execution_count": null,
|
|
|
|
1083 |   },
1084 |   {
1085 |   "cell_type": "code",
1086 | + "execution_count": 26,
1087 |   "metadata": {},
1088 | + "outputs": [
1089 | + {
1090 | + "ename": "HfHubHTTPError",
1091 | + "evalue": "401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create (Request ID: Root=1-68d628b1-575691d056937c56340182c7;4f0ea033-8e43-4260-938f-d74b744942be)\n\nInvalid credentials in Authorization header",
1092 | + "output_type": "error",
1093 | + "traceback": [
1094 | + "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
1095 | + "\u001b[31mHTTPError\u001b[39m Traceback (most recent call last)",
1096 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_http.py:407\u001b[39m, in \u001b[36mhf_raise_for_status\u001b[39m\u001b[34m(response, endpoint_name)\u001b[39m\n\u001b[32m 406\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m407\u001b[39m \u001b[43mresponse\u001b[49m\u001b[43m.\u001b[49m\u001b[43mraise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 408\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
1097 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/requests/models.py:1024\u001b[39m, in \u001b[36mResponse.raise_for_status\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1023\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m http_error_msg:\n\u001b[32m-> \u001b[39m\u001b[32m1024\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m HTTPError(http_error_msg, response=\u001b[38;5;28mself\u001b[39m)\n",
1098 | + "\u001b[31mHTTPError\u001b[39m: 401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create",
1099 | + "\nThe above exception was the direct cause of the following exception:\n",
1100 | + "\u001b[31mHfHubHTTPError\u001b[39m Traceback (most recent call last)",
1101 | + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[26]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m model.hub_model_id = \u001b[33m\"\u001b[39m\u001b[33madricl/midi_single_instrument_mistral_transformer\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommit_message\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m"\u001b[39;49m\u001b[33;43mTraining Basic Model for Mistral MidiTok Transformer Single Instrument Small\u001b[39;49m\u001b[33;43m"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m"\u001b[39;49m\u001b[33;43madricl/midi_single_instrument_mistral_transformer\u001b[39;49m\u001b[33;43m"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 4\u001b[39m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m"\u001b[39;49m\u001b[33;43m"\u001b[39;49m\u001b[43m)\u001b[49m\n",
1102 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py:4346\u001b[39m, in \u001b[36mPreTrainedModel.push_to_hub\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 4344\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m tags:\n\u001b[32m 4345\u001b[39m kwargs[\u001b[33m\"\u001b[39m\u001b[33mtags\u001b[39m\u001b[33m\"\u001b[39m] = tags\n\u001b[32m-> \u001b[39m\u001b[32m4346\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
1103 | + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/utils/hub.py:955\u001b[39m, in \u001b[36mPushToHubMixin.push_to_hub\u001b[39m\u001b[34m(self, repo_id, use_temp_dir, commit_message, private, token, max_shard_size, create_pr, safe_serialization, revision, commit_description, tags, **deprecated_kwargs)\u001b[39m\n\u001b[32m 952\u001b[39m repo_url = deprecated_kwargs.pop(\u001b[33m\"\u001b[39m\u001b[33mrepo_url\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m 953\u001b[39m organization = deprecated_kwargs.pop(\u001b[33m\"\u001b[39m\u001b[33morganization\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m--> \u001b[39m\u001b[32m955\u001b[39m repo_id = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_create_repo\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 956\u001b[39m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrepo_url\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrepo_url\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morganization\u001b[49m\u001b[43m=\u001b[49m\u001b[43morganization\u001b[49m\n\u001b[32m 957\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 959\u001b[39m \u001b[38;5;66;03m# Create a new empty model card and eventually tag it\u001b[39;00m\n\u001b[32m 960\u001b[39m model_card = create_and_tag_model_card(\n\u001b[32m 961\u001b[39m repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors\n\u001b[32m 962\u001b[39m )\n",
+
"\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/utils/hub.py:955\u001b[39m, in \u001b[36mPushToHubMixin.push_to_hub\u001b[39m\u001b[34m(self, repo_id, use_temp_dir, commit_message, private, token, max_shard_size, create_pr, safe_serialization, revision, commit_description, tags, **deprecated_kwargs)\u001b[39m\n\u001b[32m 952\u001b[39m repo_url = deprecated_kwargs.pop(\u001b[33m\"\u001b[39m\u001b[33mrepo_url\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m 953\u001b[39m organization = deprecated_kwargs.pop(\u001b[33m\"\u001b[39m\u001b[33morganization\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m--> \u001b[39m\u001b[32m955\u001b[39m repo_id = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_create_repo\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 956\u001b[39m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrepo_url\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrepo_url\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morganization\u001b[49m\u001b[43m=\u001b[49m\u001b[43morganization\u001b[49m\n\u001b[32m 957\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 959\u001b[39m \u001b[38;5;66;03m# Create a new empty model card and eventually tag it\u001b[39;00m\n\u001b[32m 960\u001b[39m model_card = create_and_tag_model_card(\n\u001b[32m 961\u001b[39m repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors\n\u001b[32m 962\u001b[39m )\n",
|
| 1104 |
+
"\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/transformers/utils/hub.py:759\u001b[39m, in \u001b[36mPushToHubMixin._create_repo\u001b[39m\u001b[34m(self, repo_id, private, token, repo_url, organization)\u001b[39m\n\u001b[32m 756\u001b[39m repo_id = repo_id.split(\u001b[33m\"\u001b[39m\u001b[33m/\u001b[39m\u001b[33m\"\u001b[39m)[-\u001b[32m1\u001b[39m]\n\u001b[32m 757\u001b[39m repo_id = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00morganization\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m759\u001b[39m url = \u001b[43mcreate_repo\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprivate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexist_ok\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 760\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m url.repo_id\n",
|
| 1105 |
+
"\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py:114\u001b[39m, in \u001b[36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 111\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[32m 112\u001b[39m kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.\u001b[34m__name__\u001b[39m, has_token=has_token, kwargs=kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1106 |
+
"\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py:3766\u001b[39m, in \u001b[36mHfApi.create_repo\u001b[39m\u001b[34m(self, repo_id, token, private, repo_type, exist_ok, resource_group_id, space_sdk, space_hardware, space_storage, space_sleep_time, space_secrets, space_variables)\u001b[39m\n\u001b[32m 3763\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 3765\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m3766\u001b[39m \u001b[43mhf_raise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43mr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3767\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[32m 3768\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m exist_ok \u001b[38;5;129;01mand\u001b[39;00m err.response.status_code == \u001b[32m409\u001b[39m:\n\u001b[32m 3769\u001b[39m \u001b[38;5;66;03m# Repo already exists and `exist_ok=True`\u001b[39;00m\n",
|
| 1107 |
+
"\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_http.py:480\u001b[39m, in \u001b[36mhf_raise_for_status\u001b[39m\u001b[34m(response, endpoint_name)\u001b[39m\n\u001b[32m 476\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m _format(HfHubHTTPError, message, response) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m 478\u001b[39m \u001b[38;5;66;03m# Convert `HTTPError` into a `HfHubHTTPError` to display request information\u001b[39;00m\n\u001b[32m 479\u001b[39m \u001b[38;5;66;03m# as well (request id and/or server error message)\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m480\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m _format(HfHubHTTPError, \u001b[38;5;28mstr\u001b[39m(e), response) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n",
|
| 1108 |
+
"\u001b[31mHfHubHTTPError\u001b[39m: 401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create (Request ID: Root=1-68d628b1-575691d056937c56340182c7;4f0ea033-8e43-4260-938f-d74b744942be)\n\nInvalid credentials in Authorization header"
|
| 1109 |
+
]
|
| 1110 |
+
}
|
| 1111 |
+
],
|
| 1112 |
"source": [
|
| 1113 |
"\n",
|
| 1114 |
"model.hub_model_id = \"adricl/midi_single_instrument_mistral_transformer\"\n",
|
|
|
|
| 1238 |
"name": "python",
|
| 1239 |
"nbconvert_exporter": "python",
|
| 1240 |
"pygments_lexer": "ipython3",
|
| 1241 |
+
"version": "3.11.11"
|
| 1242 |
},
|
| 1243 |
"vscode": {
|
| 1244 |
"interpreter": {
|
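Aside on the 401 in the cell above: the traceback shows push_to_hub being called with token="", and an empty string is sent as invalid credentials. A minimal sketch of the usual fix, authenticating once through huggingface_hub so transformers picks up the cached token; the HF_TOKEN environment variable name is an assumption here, not something the notebook sets:

    import os
    from huggingface_hub import login

    # Authenticate once; assumes a valid *write* token was exported
    # as HF_TOKEN beforehand (assumption, not set by this notebook).
    login(token=os.environ["HF_TOKEN"])

    # With the cached login, no explicit token argument is needed.
    model.push_to_hub(
        repo_id="adricl/midi_single_instrument_mistral_transformer",
        commit_message="Training Basic Model for Mistral MidiTok Transformer Single Instrument Small",
    )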
train_tokenizer.py
ADDED
|
@@ -0,0 +1,56 @@
from miditok import REMI, TokenizerConfig
from random import shuffle, sample
from pathlib import Path

# Our tokenizer's configuration
BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}
TOKENIZER_PARAMS = {
    "pitch_range": (21, 108),
    "beat_res": BEAT_RES,
    "num_velocities": 32,
    "special_tokens": ["PAD", "BOS", "EOS"],
    "use_chords": True,
    "use_rests": True,
    "use_tempos": True,
    "use_time_signatures": True,
    "use_programs": False,  # We want a single track
    "one_token_stream_for_programs": False,  # We want a single track
    "programs": list(range(0, 128)),  # -1 is drums; skip drums
    "num_tempos": 32,
    "tempo_range": (40, 250),  # (min_tempo, max_tempo)
}
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Create the REMI tokenizer
tokenizer = REMI(config)

root_data_dir = Path('/root')
root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')

tokenizer_name = "HuggingFace_Mistral_Transformer_Single_Instrument_v4_single_track.json"

# Collect every .mid/.midi file under the data directories
data_dirs = ["MIDIs"]
midi_paths = []
for data_dir in data_dirs:
    path = Path(root_data_dir / data_dir)
    midi_paths.extend(list(path.resolve().glob("**/*.mid")) + list(path.resolve().glob("**/*.midi")))

print(f"Found {len(midi_paths)} MIDI files")

shuffle(midi_paths)  # in-place shuffle; random.shuffle returns None

# We need a subset of files, otherwise training the tokenizer takes too long
percentage_to_select = 0.20
num_files_to_select = int(len(midi_paths) * percentage_to_select)

subset_midi_paths = sample(midi_paths, num_files_to_select)
print(f"Selected {len(subset_midi_paths)} MIDI files for tokenizer training")

# Train the vocabulary on the subset, then save the tokenizer config
tokenizer.train(
    vocab_size=24000,
    files_paths=subset_midi_paths,
)
tokenizer.save(root_save / tokenizer_name)
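For reference, a short sketch of how the tokenizer saved by this script could be reloaded and applied later; the params keyword and the callable interface follow miditok's documented API as I understand it, and example.mid is a hypothetical input file:

    from pathlib import Path
    from miditok import REMI

    # Reload the trained tokenizer from the JSON saved by train_tokenizer.py
    save_path = Path("/root/HuggingFace_Mistral_Transformer_Single_Instrument")
    tokenizer = REMI(params=save_path / "HuggingFace_Mistral_Transformer_Single_Instrument_v4_single_track.json")

    print(len(tokenizer))  # vocabulary size; expected around 24000 after training

    # Tokenize one file: miditok tokenizers are callable on a MIDI path and,
    # with use_programs=False, return one TokSequence per track.
    tokens = tokenizer(Path("example.mid"))  # example.mid is hypothetical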