File size: 5,483 Bytes
4ec02d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from copy import deepcopy
from pathlib import Path
from random import shuffle, sample

from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig, TokTrainingIterator
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training

from miditok.data_augmentation import augment_dataset
from torch import Tensor, argmax, torch
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig
from transformers.trainer_utils import set_seed
from tqdm import tqdm

# Filesystem layout: everything lives under this root — adjust per machine.
root_data_dir = Path('/home/wombat/Documents/projects/music/midiTok/data/')
root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')

# JSON parameter file of a previously trained REMI tokenizer (must exist in root_save).
tokenizer_name = "HuggingFace_Mistral_Transformer_Single_Instrument_v3_single_track.json"

# Rebuild the tokenizer from its saved configuration rather than training a new one.
tokenizer = REMI(params=Path(root_save / tokenizer_name))

sequence_length = 1024  # The maximum sequence length for data samples.
# Common kwargs for building DatasetMIDI objects; BOS/EOS token ids are looked up
# in the tokenizer vocabulary. NOTE(review): unused in this chunk — presumably
# consumed where the datasets were originally built; verify before removing.
kwargs_dataset = {"max_seq_len": sequence_length, "tokenizer": tokenizer, "bos_token_id": tokenizer["BOS_None"], "eos_token_id": tokenizer["EOS_None"]}


# Directory holding the pre-tokenized, serialized dataset splits.
dataset_dir = root_save / "data"
dataset_dir.mkdir(parents=True, exist_ok=True)


# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects, which can execute code on load — only load files you created yourself.
dataset_train = torch.load(Path(dataset_dir / "dataset_train.pt"), weights_only=False)
dataset_valid = torch.load(Path(dataset_dir / "dataset_valid.pt"), weights_only=False)
dataset_test = torch.load(Path(dataset_dir / "dataset_test.pt"), weights_only=False)

# Build a down-scaled Mistral architecture for symbolic-music modelling.
# The inline comments give the stock Mistral-7B defaults for comparison.
config = MistralConfig(
    vocab_size=len(tokenizer),       # taken from the miditok vocabulary (stock default 32k)
    hidden_size=512,                 # stock default 4096
    intermediate_size=2048,          # stock default 14336
    num_hidden_layers=8,             # stock default 32
    num_attention_heads=8,           # stock default 32
    num_key_value_heads=4,           # stock default 8
    sliding_window=256,              # stock default 4096
    # Only caps accepted input length; does not affect parameter count or
    # training cost (stock default 4096*32).
    max_position_embeddings=8192,
    pad_token_id=tokenizer['PAD_None'],
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None'],
)
model = AutoModelForCausalLM.from_config(config)

# Where the Trainer writes checkpoints and logs.
model_dir = root_save / 'run'
model_dir_str = str(model_dir)
print(model_dir)

# Evaluation metrics keyed by name; pretraining only tracks token-level accuracy.
metrics = {"accuracy": load_metric("accuracy")}

def compute_metrics(eval_pred):
    """
    Compute evaluation metrics for pretraining.

    Requires a preprocess_logits_for_metrics hook that has already reduced
    logits to token-id predictions (argmax or sampling).

    :param eval_pred: EvalPrediction with ``predictions`` and ``labels`` arrays
    :return: dict with the accuracy over all non-ignored (label != -100) tokens
    """
    predictions, labels = eval_pred
    # Positions labelled -100 are padding/ignored by the loss; drop them
    # before scoring so they do not inflate accuracy.
    keep = labels != -100
    return metrics["accuracy"].compute(
        predictions=predictions[keep].flatten(),
        references=labels[keep].flatten(),
    )

def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
    """
    Reduce logits to predicted token ids before the Trainer accumulates them.

    Storing only the argmax ids (long dtype) instead of full vocabulary-sized
    logit tensors drastically cuts evaluation memory usage.

    :param logits: model output logits, last dimension is the vocabulary
    :param _: labels (unused)
    :return: tensor of predicted token ids
    """
    return argmax(logits, dim=-1)

# Create config for the Trainer: pick mixed-precision flags from the hardware.
# CUDA with bf16 support -> bf16; CUDA without it -> fp16; CPU/MPS -> full fp32.
USE_CUDA = cuda_available()
print(USE_CUDA)
if not USE_CUDA:  # reuse the cached result instead of querying CUDA a second time
    FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
elif is_bf16_supported():
    BF16 = BF16_EVAL = True
    FP16 = FP16_EVAL = False
else:
    BF16 = BF16_EVAL = False
    FP16 = FP16_EVAL = True
# Fall back to Apple-silicon MPS only when no CUDA device is present.
USE_MPS = not USE_CUDA and mps_available()
# Trainer hyper-parameters. The original passed the first six arguments
# positionally (output_dir, overwrite_output_dir, do_train, do_eval,
# do_predict, evaluation_strategy); explicit keywords keep the call correct
# across transformers releases where positional order has shifted.
training_config = TrainingArguments(
    output_dir=model_dir_str,
    overwrite_output_dir=False,
    do_train=True,
    do_eval=True,
    do_predict=False,
    # NOTE(review): renamed to `eval_strategy` in transformers >= 4.41 — confirm installed version.
    evaluation_strategy="steps",
    per_device_train_batch_size=24, #76% @ 24 batch size #76% @ 32 batch size try 64 batch size next time
    per_device_eval_batch_size=24, #was 24 now 32
    gradient_accumulation_steps=2, #change this to 4
    eval_accumulation_steps=None,
    eval_steps=1000,
    learning_rate=1e-4,
    weight_decay=0.01,
    max_grad_norm=3.0,
    max_steps=40000,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    log_level="debug",
    logging_strategy="steps",
    logging_steps=20,
    save_strategy="steps",   # must match eval cadence for load_best_model_at_end
    save_steps=1000,
    save_total_limit=5,
    no_cuda=not USE_CUDA,    # NOTE(review): deprecated in newer transformers in favor of `use_cpu`
    seed=444,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    load_best_model_at_end=True,
    label_smoothing_factor=0.,
    optim="adamw_torch",
    report_to=["tensorboard"],
    gradient_checkpointing=True,  # trades recompute for activation memory
    dataloader_num_workers=8, #added to fix thrashing issue with the gpu not having enough data to process
    dataloader_pin_memory=True, #we want the dataset in memory
    torch_compile=True, #added to speed up
)

# Batch collator: pads with PAD_None and copies inputs as labels, as needed
# for causal-LM training (the model shifts labels internally).
collator = DataCollator(tokenizer["PAD_None"], copy_inputs_as_labels=True, pad_on_left=True) #not sure about the pad_on_left, it might get better results
trainer = Trainer(
    model=model,
    args=training_config,
    data_collator=collator,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid,  # test split is intentionally held out from the Trainer
    compute_metrics=compute_metrics,
    callbacks=None,
    # Reduces logits to argmax ids before accumulation to keep eval memory tractable.
    preprocess_logits_for_metrics=preprocess_logits,
    
)



# Notebook magic kept for reference: mitigates CUDA allocator fragmentation.
#%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Training: run to max_steps, then persist the final/best model and run state
# (with load_best_model_at_end=True the best checkpoint is restored first).
train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()  # optimizer/scheduler/RNG state, enables later resume