Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -75,19 +75,18 @@ dtype = 'bfloat16'
|
|
| 75 |
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
| 76 |
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
| 77 |
|
| 78 |
-
SEQ_LEN =
|
| 79 |
-
PAD_IDX =
|
| 80 |
-
|
| 81 |
-
model = TransformerWrapper(
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
)
|
| 91 |
|
| 92 |
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
|
| 93 |
|
|
@@ -100,6 +99,9 @@ model.load_state_dict(torch.load(model_checkpoint, map_location=device_type, wei
|
|
| 100 |
|
| 101 |
model = torch.compile(model, mode='max-autotune')
|
| 102 |
|
|
|
|
|
|
|
|
|
|
| 103 |
print('=' * 70)
|
| 104 |
print('Done!')
|
| 105 |
print('=' * 70)
|
|
@@ -113,47 +115,46 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
|
|
| 113 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
| 114 |
|
| 115 |
escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
| 116 |
-
escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
|
| 117 |
|
| 118 |
-
sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes
|
| 119 |
-
|
| 120 |
-
if melody_patch == -1:
|
| 121 |
-
zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
|
| 122 |
-
|
| 123 |
-
else:
|
| 124 |
-
mel_score = [e for e in sp_escore_notes if e[6] == melody_patch]
|
| 125 |
-
|
| 126 |
-
if mel_score:
|
| 127 |
-
zscore = TMIDIX.recalculate_score_timings(mel_score)
|
| 128 |
-
|
| 129 |
-
else:
|
| 130 |
-
zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
|
| 131 |
|
| 132 |
-
|
| 133 |
|
| 134 |
-
|
| 135 |
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
pc = cscore[0]
|
| 139 |
|
| 140 |
for c in cscore:
|
| 141 |
-
score.append(max(0, min(127, c[0][1]-pc[0][1])))
|
| 142 |
|
| 143 |
-
|
| 144 |
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
|
|
|
| 151 |
|
| 152 |
pc = c
|
| 153 |
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
-
return score,
|
| 157 |
|
| 158 |
#==================================================================================
|
| 159 |
|
|
@@ -268,8 +269,7 @@ def Generate_Accompaniment(input_midi,
|
|
| 268 |
print('=' * 70)
|
| 269 |
print('Generating...')
|
| 270 |
|
| 271 |
-
|
| 272 |
-
model.eval()
|
| 273 |
|
| 274 |
#==================================================================
|
| 275 |
|
|
|
|
| 75 |
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
| 76 |
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
| 77 |
|
| 78 |
+
SEQ_LEN = 1536
|
| 79 |
+
PAD_IDX = 708
|
| 80 |
+
|
| 81 |
+
model = TransformerWrapper(num_tokens = PAD_IDX+1,
|
| 82 |
+
max_seq_len = SEQ_LEN,
|
| 83 |
+
attn_layers = Decoder(dim = 2048,
|
| 84 |
+
depth = 8,
|
| 85 |
+
heads = 32,
|
| 86 |
+
rotary_pos_emb = True,
|
| 87 |
+
attn_flash = True
|
| 88 |
+
)
|
| 89 |
+
)
|
|
|
|
| 90 |
|
| 91 |
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
|
| 92 |
|
|
|
|
| 99 |
|
| 100 |
model = torch.compile(model, mode='max-autotune')
|
| 101 |
|
| 102 |
+
model.to(device_type)
|
| 103 |
+
model.eval()
|
| 104 |
+
|
| 105 |
print('=' * 70)
|
| 106 |
print('Done!')
|
| 107 |
print('=' * 70)
|
|
|
|
| 115 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
| 116 |
|
| 117 |
escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
|
|
|
| 118 |
|
| 119 |
+
sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
|
| 120 |
+
zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
escore = TMIDIX.augment_enhanced_score_notes(zscore, timings_divider=32)
|
| 123 |
|
| 124 |
+
escore = TMIDIX.fix_escore_notes_durations(escore)
|
| 125 |
|
| 126 |
+
cscore = TMIDIX.chordify_score([1000, escore])
|
| 127 |
+
|
| 128 |
+
score = []
|
| 129 |
+
chords = []
|
| 130 |
|
| 131 |
pc = cscore[0]
|
| 132 |
|
| 133 |
for c in cscore:
|
|
|
|
| 134 |
|
| 135 |
+
tones_chord = sorted(set([e[4] % 12 for e in c]))
|
| 136 |
|
| 137 |
+
if tones_chord not in TMIDIX.ALL_CHORDS_SORTED:
|
| 138 |
+
tones_chord = TMIDIX.check_and_fix_tones_chord(tones_chord, use_full_chords=False)
|
| 139 |
+
|
| 140 |
+
chord_tok = TMIDIX.ALL_CHORDS_SORTED.index(tones_chord)
|
| 141 |
+
chords.append(chord_tok+384)
|
| 142 |
+
|
| 143 |
+
score.append(chord_tok+384)
|
| 144 |
+
score.append(max(0, min(127, c[0][1]-pc[0][1])))
|
| 145 |
|
| 146 |
+
for n in c:
|
| 147 |
+
score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
|
| 148 |
|
| 149 |
pc = c
|
| 150 |
|
| 151 |
+
print('Done!')
|
| 152 |
+
print('=' * 70)
|
| 153 |
+
print('Score has', len(chords), 'chords')
|
| 154 |
+
print('Score has', len(score), 'tokens')
|
| 155 |
+
print('=' * 70)
|
| 156 |
|
| 157 |
+
return score, chords
|
| 158 |
|
| 159 |
#==================================================================================
|
| 160 |
|
|
|
|
| 269 |
print('=' * 70)
|
| 270 |
print('Generating...')
|
| 271 |
|
| 272 |
+
|
|
|
|
| 273 |
|
| 274 |
#==================================================================
|
| 275 |
|