Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -110,7 +110,7 @@ print('=' * 70)
|
|
| 110 |
|
| 111 |
#==================================================================================
|
| 112 |
|
| 113 |
-
def load_midi(input_midi
|
| 114 |
|
| 115 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
| 116 |
|
|
@@ -159,61 +159,13 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
|
|
| 159 |
#==================================================================================
|
| 160 |
|
| 161 |
@spaces.GPU
|
| 162 |
-
def
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
#===============================================================================
|
| 171 |
-
|
| 172 |
-
def generate_full_seq(input_seq,
|
| 173 |
-
max_toks=3072,
|
| 174 |
-
temperature=0.9,
|
| 175 |
-
top_k_value=15,
|
| 176 |
-
verbose=True
|
| 177 |
-
):
|
| 178 |
-
|
| 179 |
-
seq_abs_run_time = sum([t for t in input_seq if t < 128])
|
| 180 |
-
|
| 181 |
-
cur_time = 0
|
| 182 |
-
|
| 183 |
-
full_seq = copy.deepcopy(input_seq)
|
| 184 |
-
|
| 185 |
-
toks_counter = 0
|
| 186 |
-
|
| 187 |
-
while cur_time <= seq_abs_run_time+32:
|
| 188 |
-
|
| 189 |
-
if verbose:
|
| 190 |
-
if toks_counter % 128 == 0:
|
| 191 |
-
print('Generated', toks_counter, 'tokens')
|
| 192 |
-
|
| 193 |
-
x = torch.LongTensor(full_seq).cuda()
|
| 194 |
-
|
| 195 |
-
with ctx:
|
| 196 |
-
out = model.generate(x,
|
| 197 |
-
1,
|
| 198 |
-
filter_logits_fn=top_k,
|
| 199 |
-
filter_kwargs={'k': top_k_value},
|
| 200 |
-
temperature=temperature,
|
| 201 |
-
return_prime=False,
|
| 202 |
-
verbose=False)
|
| 203 |
-
|
| 204 |
-
y = out.tolist()[0][0]
|
| 205 |
-
|
| 206 |
-
if y < 128:
|
| 207 |
-
cur_time += y
|
| 208 |
-
|
| 209 |
-
full_seq.append(y)
|
| 210 |
-
|
| 211 |
-
toks_counter += 1
|
| 212 |
-
|
| 213 |
-
if toks_counter == max_toks:
|
| 214 |
-
return full_seq
|
| 215 |
-
|
| 216 |
-
return full_seq
|
| 217 |
|
| 218 |
#===============================================================================
|
| 219 |
|
|
@@ -225,65 +177,45 @@ def Generate_Accompaniment(input_midi,
|
|
| 225 |
print('=' * 70)
|
| 226 |
print('Requested settings:')
|
| 227 |
print('=' * 70)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
print('Input MIDI file name:', fn)
|
| 232 |
-
|
| 233 |
-
else:
|
| 234 |
-
print('Input sample melody:', input_melody)
|
| 235 |
print('Source melody patch:', melody_patch)
|
| 236 |
print('Use nth melody note:', use_nth_note)
|
| 237 |
print('Model temperature:', model_temperature)
|
| 238 |
-
print('Model top
|
| 239 |
|
| 240 |
print('=' * 70)
|
| 241 |
|
| 242 |
#==================================================================
|
| 243 |
|
| 244 |
-
print('
|
| 245 |
-
|
| 246 |
-
if input_midi:
|
| 247 |
-
inp_mel = 'Custom MIDI'
|
| 248 |
-
score, score_list = load_midi(input_midi.name, melody_patch, use_nth_note)
|
| 249 |
-
|
| 250 |
-
else:
|
| 251 |
-
mel_list = [m[0].lower() for m in popular_hook_melodies]
|
| 252 |
|
| 253 |
-
|
| 254 |
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
inp_mel = m.title()
|
| 258 |
-
break
|
| 259 |
-
|
| 260 |
-
score = popular_hook_melodies[[m[0] for m in popular_hook_melodies].index(inp_mel)][1]
|
| 261 |
-
score_list = [[[score[i]], score[i+1:i+3]] for i in range(0, len(score)-3, 3)]
|
| 262 |
-
|
| 263 |
-
print('Selected melody:', inp_mel)
|
| 264 |
-
|
| 265 |
-
print('Sample score events', score[:12])
|
| 266 |
|
| 267 |
#==================================================================
|
| 268 |
|
| 269 |
print('=' * 70)
|
| 270 |
print('Generating...')
|
| 271 |
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
#==================================================================
|
| 279 |
-
|
| 280 |
-
input_seq = generate_full_seq(start_score_seq,
|
| 281 |
-
max_toks=MAX_GEN_TOKS,
|
| 282 |
-
temperature=model_temperature,
|
| 283 |
-
top_k_value=model_sampling_top_k,
|
| 284 |
-
)
|
| 285 |
-
|
| 286 |
-
final_song = input_seq[len(start_score_seq):]
|
| 287 |
|
| 288 |
print('=' * 70)
|
| 289 |
print('Done!')
|
|
@@ -392,7 +324,7 @@ with gr.Blocks() as demo:
|
|
| 392 |
#==================================================================================
|
| 393 |
|
| 394 |
gr.Markdown("## Upload source melody MIDI or enter a search query for a sample melody below")
|
| 395 |
-
gr.Markdown("###
|
| 396 |
|
| 397 |
input_midi = gr.File(label="Input MIDI",
|
| 398 |
file_types=[".midi", ".mid", ".kar"]
|
|
@@ -400,10 +332,8 @@ with gr.Blocks() as demo:
|
|
| 400 |
|
| 401 |
gr.Markdown("## Generation options")
|
| 402 |
|
| 403 |
-
melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
|
| 404 |
-
use_nth_note = gr.Slider(1, 8, value=1, step=1, label="Use each nth melody note")
|
| 405 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 406 |
-
|
| 407 |
|
| 408 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 409 |
|
|
@@ -414,12 +344,10 @@ with gr.Blocks() as demo:
|
|
| 414 |
output_plot = gr.Plot(label="MIDI score plot")
|
| 415 |
output_midi = gr.File(label="MIDI file", file_types=[".mid"])
|
| 416 |
|
| 417 |
-
generate_btn.click(
|
| 418 |
[input_midi,
|
| 419 |
-
melody_patch,
|
| 420 |
-
use_nth_note,
|
| 421 |
model_temperature,
|
| 422 |
-
|
| 423 |
],
|
| 424 |
[output_audio,
|
| 425 |
output_plot,
|
|
@@ -428,19 +356,17 @@ with gr.Blocks() as demo:
|
|
| 428 |
)
|
| 429 |
|
| 430 |
gr.Examples(
|
| 431 |
-
[["Sharing The Night Together.kar",
|
| 432 |
],
|
| 433 |
[input_midi,
|
| 434 |
-
melody_patch,
|
| 435 |
-
use_nth_note,
|
| 436 |
model_temperature,
|
| 437 |
-
|
| 438 |
],
|
| 439 |
[output_audio,
|
| 440 |
output_plot,
|
| 441 |
output_midi
|
| 442 |
],
|
| 443 |
-
|
| 444 |
)
|
| 445 |
|
| 446 |
#==================================================================================
|
|
|
|
| 110 |
|
| 111 |
#==================================================================================
|
| 112 |
|
| 113 |
+
def load_midi(input_midi):
|
| 114 |
|
| 115 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
| 116 |
|
|
|
|
| 159 |
#==================================================================================
|
| 160 |
|
| 161 |
@spaces.GPU
|
| 162 |
+
def Generate_Chords_Textures(input_midi,
|
| 163 |
+
input_melody,
|
| 164 |
+
melody_patch,
|
| 165 |
+
use_nth_note,
|
| 166 |
+
model_temperature,
|
| 167 |
+
model_sampling_top_p
|
| 168 |
+
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
#===============================================================================
|
| 171 |
|
|
|
|
| 177 |
print('=' * 70)
|
| 178 |
print('Requested settings:')
|
| 179 |
print('=' * 70)
|
| 180 |
+
fn = os.path.basename(input_midi)
|
| 181 |
+
fn1 = fn.split('.')[0]
|
| 182 |
+
print('Input MIDI file name:', fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
print('Source melody patch:', melody_patch)
|
| 184 |
print('Use nth melody note:', use_nth_note)
|
| 185 |
print('Model temperature:', model_temperature)
|
| 186 |
+
print('Model top p:', model_sampling_top_p)
|
| 187 |
|
| 188 |
print('=' * 70)
|
| 189 |
|
| 190 |
#==================================================================
|
| 191 |
|
| 192 |
+
print('Loading MIDI...')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
+
score, chords = load_midi(input_midi)
|
| 195 |
|
| 196 |
+
print('Sample score chords', chords[:10])
|
| 197 |
+
print('Sample score tokens', score[:10])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
#==================================================================
|
| 200 |
|
| 201 |
print('=' * 70)
|
| 202 |
print('Generating...')
|
| 203 |
|
| 204 |
+
x = torch.LongTensor([705] + chords[:128] + [706]).cuda()
|
| 205 |
+
|
| 206 |
+
with ctx:
|
| 207 |
+
out = model.generate(x,
|
| 208 |
+
1024,
|
| 209 |
+
temperature=model_temperature,
|
| 210 |
+
filter_logits_fn=top_p,
|
| 211 |
+
filter_kwargs={'thres': model_sampling_top_p},
|
| 212 |
+
return_prime=False,
|
| 213 |
+
eos_token=707,
|
| 214 |
+
verbose=False)
|
| 215 |
+
|
| 216 |
+
final_song = out.tolist()
|
| 217 |
|
| 218 |
#==================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
print('=' * 70)
|
| 221 |
print('Done!')
|
|
|
|
| 324 |
#==================================================================================
|
| 325 |
|
| 326 |
gr.Markdown("## Upload source melody MIDI or enter a search query for a sample melody below")
|
| 327 |
+
gr.Markdown("### PLEASE NOTE: The demo is limited and will only texture first 128 chords of the MIDI file")
|
| 328 |
|
| 329 |
input_midi = gr.File(label="Input MIDI",
|
| 330 |
file_types=[".midi", ".mid", ".kar"]
|
|
|
|
| 332 |
|
| 333 |
gr.Markdown("## Generation options")
|
| 334 |
|
|
|
|
|
|
|
| 335 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 336 |
+
model_sampling_top_p = gr.Slider(0.1, 0.99, value=0.96, step=0.01, label="Model sampling top p value")
|
| 337 |
|
| 338 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 339 |
|
|
|
|
| 344 |
output_plot = gr.Plot(label="MIDI score plot")
|
| 345 |
output_midi = gr.File(label="MIDI file", file_types=[".mid"])
|
| 346 |
|
| 347 |
+
generate_btn.click(Generate_Chords_Textures,
|
| 348 |
[input_midi,
|
|
|
|
|
|
|
| 349 |
model_temperature,
|
| 350 |
+
model_sampling_top_p
|
| 351 |
],
|
| 352 |
[output_audio,
|
| 353 |
output_plot,
|
|
|
|
| 356 |
)
|
| 357 |
|
| 358 |
gr.Examples(
|
| 359 |
+
[["Sharing The Night Together.kar", 0.9, 0.96]
|
| 360 |
],
|
| 361 |
[input_midi,
|
|
|
|
|
|
|
| 362 |
model_temperature,
|
| 363 |
+
model_sampling_top_p
|
| 364 |
],
|
| 365 |
[output_audio,
|
| 366 |
output_plot,
|
| 367 |
output_midi
|
| 368 |
],
|
| 369 |
+
Generate_Chords_Textures
|
| 370 |
)
|
| 371 |
|
| 372 |
#==================================================================================
|