import torch
import torch.nn as nn
import os
import random
import pretty_midi
from processor import encode_midi, decode_midi
from utilities.argument_funcs import parse_generate_args, print_generate_args
from model.music_transformer import MusicTransformer
from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, process_midi
from torch.utils.data import DataLoader
from torch.optim import Adam
from utilities.constants import *
from utilities.device import get_device, use_cuda
# main
def main():
"""
----------
Author: Damon Gwinn
----------
Entry point. Generates music from a model specified by command line arguments
----------
"""
args = parse_generate_args()
print_generate_args(args)
if (args.force_cpu):
use_cuda(False)
print("WARNING: Forced CPU usage, expect model to perform slower")
print("")
os.makedirs(args.output_dir, exist_ok=True)
    # --- MODIFIED LOGIC ---
    # The primer can be None, an integer index into the dataset, or a file path
    if (args.primer_file is None):
        # --- Load dataset ONLY if no primer file is given ---
        print("No primer file provided, loading dataset to pick a random primer...")
        _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
        # --- --- --- --- --- --- --- --- --- --- --- --- --- ---

        idx = random.randrange(len(dataset))
        primer, _ = dataset[idx]
        primer = primer.to(get_device())
        num_primer = primer.shape[0]
        print("Using primer index:", idx, "(", dataset.data_files[idx], ")")
    else:
        # --- Primer file is provided, NO DATASET NEEDED (unless it's an index) ---
        f = args.primer_file

        # --- NEW: Check for "silence" ---
        if (f.lower() == "silence"):
            print("Generating from silence...")
            # Create a primer with one token: a medium velocity (64 // 4 = 16)
            # Velocity START_IDX = 356. Token = 356 + 16 = 372
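            # NOTE (assumption): the hard-coded token 372 relies on processor.py's event
            # vocabulary (note_on 0-127, note_off 128-255, time_shift 256-355, velocity
            # bins starting at 356). If those constants differ, recompute this token.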
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]  # This will be 1

        # This part handles if the primer is an integer index (e.g., "3")
        elif (f.isdigit()):
            print("Primer file is an index, loading dataset...")
            # --- Load dataset ONLY if primer is an index ---
            _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
            # --- --- --- --- --- --- --- --- --- --- --- ---
            idx = int(f)
            primer, _ = dataset[idx]
            primer = primer.to(get_device())
            num_primer = primer.shape[0]
            print("Using primer index:", idx, "(", dataset.data_files[idx], ")")

        # This part handles if the primer is a MIDI file path (e.g., "my_primer.mid")
        else:
            print("Primer file is a MIDI path. Loading and tokenizing...")
            raw_mid = encode_midi(f)
            if (len(raw_mid) == 0):
                print("Error: No midi messages in primer file:", f)
                return

            primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
            primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]  # Get the actual primer length
            print("Using primer file:", f)
    # --- END MODIFIED LOGIC ---
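
    # At this point `primer` is a 1-D tensor of event tokens on the target device and
    # `num_primer` is its length, used later to strip the primer from the generated output.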

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())

    model.load_state_dict(torch.load(args.model_weights, map_location=get_device()))

    # --- MODIFICATION: Don't save a primer if we started from silence ---
    # (args.primer_file can be None when a random dataset primer was picked, so check that first)
    if (args.primer_file is None or args.primer_file.lower() != "silence"):
        f_path = os.path.join(args.output_dir, "primer.mid")
        decode_midi(primer.cpu().numpy(), file_path=f_path)
    # --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
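
    # primer.mid (when written) can be compared against the beam.mid / rand.mid files
    # produced below, which contain only the newly generated continuation.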

    # GENERATION
    model.eval()
    with torch.set_grad_enabled(False):
        if (args.beam > 0):
            print("BEAM:", args.beam)
            beam_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=args.beam)

            # --- MODIFICATION: Slice the primer ---
            generated_only = beam_seq[0][num_primer:]
            # ------------------------------------

            f_path = os.path.join(args.output_dir, "beam.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)
        else:
            print("RAND DIST")
            rand_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=0)

            # --- MODIFICATION: Slice the primer ---
            generated_only = rand_seq[0][num_primer:]
            # ------------------------------------

            f_path = os.path.join(args.output_dir, "rand.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)

if __name__ == "__main__":
    main()
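
# Example invocation (a sketch, not verified against this repo's argument parser): the
# exact flag names are defined in utilities/argument_funcs.parse_generate_args, which is
# not shown here, so the flags below are assumed to mirror the args attributes used above.
#
#   # Continue a primer MIDI file, keeping 256 primer tokens and generating up to 1024 total:
#   python generate.py -output_dir ./output -model_weights ./saved/model.pickle \
#       -primer_file my_primer.mid -num_prime 256 -target_seq_length 1024
#
#   # Generate from silence (single medium-velocity primer token; no -midi_root needed
#   # because the dataset is never loaded in this case):
#   python generate.py -output_dir ./output -model_weights ./saved/model.pickle \
#       -primer_file silence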