# NanoMaestro / generate.py
# utkucoban's picture
# NanoMaestro Full model weights released
# 47dfee0 verified
import torch
import torch.nn as nn
import os
import random
import pretty_midi
import processor
from processor import encode_midi, decode_midi
from utilities.argument_funcs import parse_generate_args, print_generate_args
from model.music_transformer import MusicTransformer
from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, process_midi
from torch.utils.data import DataLoader
from torch.optim import Adam
from utilities.constants import *
from utilities.device import get_device, use_cuda
# main
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Generates music from a model specified by command line arguments.

    The primer is selected from args.primer_file:
      - None: a random entry of the third dataset returned by
        create_epiano_datasets(args.midi_root, args.num_prime)
      - "silence" (case-insensitive): a single velocity token, no musical primer
      - a string of digits: that index into the same dataset
      - anything else: treated as a path to a MIDI file, which is tokenized

    Output (written to args.output_dir):
      - primer.mid, unless generation started from silence
      - beam.mid when args.beam > 0, otherwise rand.mid; only the tokens
        generated after the primer are written

    Prints an error and returns early when the primer cannot be obtained
    (empty dataset, out-of-range index, or a MIDI file with no messages).
    ----------
    """

    args = parse_generate_args()
    print_generate_args(args)

    if args.force_cpu:
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    # Can be None, "silence", an integer index to dataset, or a file path
    primer_spec = args.primer_file

    # Resolved once so later code never calls .lower() on a None primer_file
    from_silence = primer_spec is not None and primer_spec.lower() == "silence"

    if from_silence:
        print("Generating from silence...")
        # Create a primer with one token: a medium velocity (64 // 4 = 16)
        # Velocity START_IDX = 356. Token = 356 + 16 = 372
        primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=get_device())

    elif primer_spec is None or primer_spec.isdigit():
        # Both of these cases need the dataset; it is not loaded otherwise
        if primer_spec is None:
            print("No primer file provided, loading dataset to pick a random primer...")
        else:
            print("Primer file is an index, loading dataset...")

        _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)

        num_entries = len(dataset)
        if num_entries == 0:
            print("Error: Dataset is empty, cannot pick a primer from:", args.midi_root)
            return

        if primer_spec is None:
            idx = random.randrange(num_entries)
        else:
            idx = int(primer_spec)
            if idx >= num_entries:
                print("Error: Primer index", idx, "is out of range for a dataset of size", num_entries)
                return

        primer, _ = dataset[idx]
        primer = primer.to(get_device())
        print("Using primer index:", idx, "(", dataset.data_files[idx], ")")

    else:
        # The primer is a MIDI file path (e.g., "my_primer.mid")
        print("Primer file is a MIDI path. Loading and tokenizing...")
        raw_mid = encode_midi(primer_spec)
        if len(raw_mid) == 0:
            print("Error: No midi messages in primer file:", primer_spec)
            return

        primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
        primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
        print("Using primer file:", primer_spec)

    # Actual primer length; used below to strip the primer from the output
    num_primer = primer.shape[0]

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())

    # map_location lets a checkpoint saved on GPU load on a CPU-only / forced-CPU run
    model.load_state_dict(torch.load(args.model_weights, map_location=get_device()))

    # Don't save a primer if we started from silence
    if not from_silence:
        primer_path = os.path.join(args.output_dir, "primer.mid")
        decode_midi(primer.cpu().numpy(), file_path=primer_path)

    # GENERATION
    if args.beam > 0:
        print("BEAM:", args.beam)
        beam = args.beam
        out_name = "beam.mid"
    else:
        print("RAND DIST")
        beam = 0
        out_name = "rand.mid"

    model.eval()
    with torch.no_grad():
        generated = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=beam)

        # Slice off the primer so only newly generated tokens are written
        generated_only = generated[0][num_primer:]

        out_path = os.path.join(args.output_dir, out_name)
        decode_midi(generated_only.cpu().numpy(), file_path=out_path)
# Run generation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()