projectlosangeles commited on
Commit
9368837
·
verified ·
1 Parent(s): ef689dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -41
app.py CHANGED
@@ -75,19 +75,18 @@ dtype = 'bfloat16'
75
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
76
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
77
 
78
- SEQ_LEN = 4096
79
- PAD_IDX = 1794
80
-
81
- model = TransformerWrapper(
82
- num_tokens = PAD_IDX+1,
83
- max_seq_len = SEQ_LEN,
84
- attn_layers = Decoder(dim = 2048,
85
- depth = 4,
86
- heads = 32,
87
- rotary_pos_emb = True,
88
- attn_flash = True
89
- )
90
- )
91
 
92
  model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
93
 
@@ -100,6 +99,9 @@ model.load_state_dict(torch.load(model_checkpoint, map_location=device_type, wei
100
 
101
  model = torch.compile(model, mode='max-autotune')
102
 
 
 
 
103
  print('=' * 70)
104
  print('Done!')
105
  print('=' * 70)
@@ -113,47 +115,46 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
113
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
114
 
115
  escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
116
- escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
117
 
118
- sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)
119
-
120
- if melody_patch == -1:
121
- zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
122
-
123
- else:
124
- mel_score = [e for e in sp_escore_notes if e[6] == melody_patch]
125
-
126
- if mel_score:
127
- zscore = TMIDIX.recalculate_score_timings(mel_score)
128
-
129
- else:
130
- zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
131
 
132
- cscore = TMIDIX.chordify_score([1000, zscore])[:MAX_MELODY_NOTES:use_nth_note]
133
 
134
- score = []
135
 
136
- score_list = []
 
 
 
137
 
138
  pc = cscore[0]
139
 
140
  for c in cscore:
141
- score.append(max(0, min(127, c[0][1]-pc[0][1])))
142
 
143
- scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]
144
 
145
- n = c[0]
146
-
147
- score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
148
- scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
 
 
 
 
149
 
150
- score_list.append(scl)
 
151
 
152
  pc = c
153
 
154
- score_list.append(scl)
 
 
 
 
155
 
156
- return score, score_list
157
 
158
  #==================================================================================
159
 
@@ -268,8 +269,7 @@ def Generate_Accompaniment(input_midi,
268
  print('=' * 70)
269
  print('Generating...')
270
 
271
- model.to(device_type)
272
- model.eval()
273
 
274
  #==================================================================
275
 
 
75
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
76
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
77
 
78
+ SEQ_LEN = 1536
79
+ PAD_IDX = 708
80
+
81
+ model = TransformerWrapper(num_tokens = PAD_IDX+1,
82
+ max_seq_len = SEQ_LEN,
83
+ attn_layers = Decoder(dim = 2048,
84
+ depth = 8,
85
+ heads = 32,
86
+ rotary_pos_emb = True,
87
+ attn_flash = True
88
+ )
89
+ )
 
90
 
91
  model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
92
 
 
99
 
100
  model = torch.compile(model, mode='max-autotune')
101
 
102
+ model.to(device_type)
103
+ model.eval()
104
+
105
  print('=' * 70)
106
  print('Done!')
107
  print('=' * 70)
 
115
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
116
 
117
  escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
 
118
 
119
+ sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
120
+ zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ escore = TMIDIX.augment_enhanced_score_notes(zscore, timings_divider=32)
123
 
124
+ escore = TMIDIX.fix_escore_notes_durations(escore)
125
 
126
+ cscore = TMIDIX.chordify_score([1000, escore])
127
+
128
+ score = []
129
+ chords = []
130
 
131
  pc = cscore[0]
132
 
133
  for c in cscore:
 
134
 
135
+ tones_chord = sorted(set([e[4] % 12 for e in c]))
136
 
137
+ if tones_chord not in TMIDIX.ALL_CHORDS_SORTED:
138
+ tones_chord = TMIDIX.check_and_fix_tones_chord(tones_chord, use_full_chords=False)
139
+
140
+ chord_tok = TMIDIX.ALL_CHORDS_SORTED.index(tones_chord)
141
+ chords.append(chord_tok+384)
142
+
143
+ score.append(chord_tok+384)
144
+ score.append(max(0, min(127, c[0][1]-pc[0][1])))
145
 
146
+ for n in c:
147
+ score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
148
 
149
  pc = c
150
 
151
+ print('Done!')
152
+ print('=' * 70)
153
+ print('Score has', len(chords), 'chords')
154
+ print('Score hss', len(score), 'tokens')
155
+ print('=' * 70)
156
 
157
+ return score, chords
158
 
159
  #==================================================================================
160
 
 
269
  print('=' * 70)
270
  print('Generating...')
271
 
272
+
 
273
 
274
  #==================================================================
275