|
|
import torch |
|
|
import torch.nn as nn |
|
|
import os |
|
|
import random |
|
|
import pretty_midi |
|
|
import processor |
|
|
|
|
|
from processor import encode_midi, decode_midi |
|
|
|
|
|
from utilities.argument_funcs import parse_generate_args, print_generate_args |
|
|
from model.music_transformer import MusicTransformer |
|
|
from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, process_midi |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.optim import Adam |
|
|
|
|
|
from utilities.constants import * |
|
|
from utilities.device import get_device, use_cuda |
|
|
|
|
|
|
|
|
|
|
|
def _primer_from_dataset(dataset, idx):
    """Fetch the primer sequence at *idx* from an e_piano dataset.

    Moves the primer tensor to the active device and logs which file it
    came from. Returns (primer_tensor, num_primer_tokens).
    """
    primer, _ = dataset[idx]
    primer = primer.to(get_device())
    print("Using primer index:", idx, "(", dataset.data_files[idx], ")")
    return primer, primer.shape[0]


def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Generates music from a model specified by command line arguments
    ----------
    """

    args = parse_generate_args()
    print_generate_args(args)

    if (args.force_cpu):
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    # Resolve the primer. Four cases:
    #   None           -> random index into the test split of the dataset
    #   "silence"      -> a single token (no real primer)
    #   a digit string -> explicit index into the test split
    #   anything else  -> path to a MIDI file to tokenize
    if (args.primer_file is None):
        print("No primer file provided, loading dataset to pick a random primer...")
        # Only the third (test) split is used; train/val are discarded.
        _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
        idx = random.randrange(len(dataset))
        primer, num_primer = _primer_from_dataset(dataset, idx)
        f = str(idx)
    else:
        f = args.primer_file

        if (f.lower() == "silence"):
            print("Generating from silence...")
            # 372 is the token id used to represent silence in the
            # processor's event vocabulary — TODO confirm against processor
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]
        elif (f.isdigit()):
            print("Primer file is an index, loading dataset...")
            _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
            primer, num_primer = _primer_from_dataset(dataset, int(f))
        else:
            print("Primer file is a MIDI path. Loading and tokenizing...")
            raw_mid = encode_midi(f)
            if (len(raw_mid) == 0):
                print("Error: No midi messages in primer file:", f)
                return

            primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
            primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]
            print("Using primer file:", f)

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())

    # map_location lets CPU-only machines (e.g. with --force_cpu) load
    # checkpoints that were saved from a CUDA device.
    model.load_state_dict(torch.load(args.model_weights, map_location=get_device()))

    # Save the primer for reference alongside the generated output.
    # BUG FIX: the original checked args.primer_file.lower(), which raised
    # AttributeError when no primer file was given (args.primer_file is None);
    # the local f is always a valid string here.
    if (f.lower() != "silence"):
        f_path = os.path.join(args.output_dir, "primer.mid")
        decode_midi(primer.cpu().numpy(), file_path=f_path)

    model.eval()
    with torch.set_grad_enabled(False):
        if (args.beam > 0):
            print("BEAM:", args.beam)
            beam_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=args.beam)

            # Drop the primer tokens so only newly generated music is written.
            generated_only = beam_seq[0][num_primer:]

            f_path = os.path.join(args.output_dir, "beam.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)
        else:
            print("RAND DIST")
            rand_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=0)

            # Drop the primer tokens so only newly generated music is written.
            generated_only = rand_seq[0][num_primer:]

            f_path = os.path.join(args.output_dir, "rand.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)
|
|
# Script entry point: run generation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":


    main()