import torch
import torch.nn as nn
import os
import random
import pretty_midi
import processor

from processor import encode_midi, decode_midi

from utilities.argument_funcs import parse_generate_args, print_generate_args
from model.music_transformer import MusicTransformer
from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, process_midi
from torch.utils.data import DataLoader
from torch.optim import Adam

from utilities.constants import *
from utilities.device import get_device, use_cuda

# Single-token primer used when generating "from silence".
# Per the original inline note: velocity tokens start at index 356 and a medium
# velocity (64 // 4 = 16) gives 356 + 16 = 372.
# NOTE(review): confirm against the processor's vocabulary layout.
_SILENCE_PRIMER_TOKEN = 372


def _is_silence(primer_file):
    """
    ----------
    Returns True when the primer argument is the special keyword "silence"
    (case-insensitive). Safe to call with None (no primer argument given).
    ----------
    """

    return primer_file is not None and primer_file.lower() == "silence"


def _primer_from_dataset(args, idx=None):
    """
    ----------
    Loads the dataset and returns the primer sequence at position idx, moved to
    the active device. When idx is None a random position is chosen.

    Returns None (after printing an error) if the dataset is empty or idx is
    outside the dataset.
    ----------
    """

    # Only the third dataset returned by create_epiano_datasets is used here
    _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)

    num_files = len(dataset)
    if num_files == 0:
        # random.randrange(0) would raise ValueError; report the problem instead
        print("Error: No sequences found in dataset at:", args.midi_root)
        return None

    if idx is None:
        idx = random.randrange(num_files)
    elif idx >= num_files:
        print("Error: Primer index", idx, "is out of range for dataset of size", num_files)
        return None

    primer, _ = dataset[idx]
    primer = primer.to(get_device())

    print("Using primer index:", idx, "(", dataset.data_files[idx], ")")
    return primer


def _load_primer(args):
    """
    ----------
    Builds the primer tensor described by args.primer_file, which can be:
        None        -> random sequence from the dataset
        "silence"   -> single medium-velocity token
        an integer  -> index into the dataset (e.g. "3")
        otherwise   -> path to a MIDI file (e.g. "my_primer.mid")

    Returns the primer tensor on the active device, or None if the primer could
    not be created (an error message has already been printed).
    ----------
    """

    primer_file = args.primer_file

    if primer_file is None:
        print("No primer file provided, loading dataset to pick a random primer...")
        return _primer_from_dataset(args)

    if _is_silence(primer_file):
        print("Generating from silence...")
        return torch.tensor([_SILENCE_PRIMER_TOKEN], dtype=TORCH_LABEL_TYPE, device=get_device())

    if primer_file.isdigit():
        print("Primer file is an index, loading dataset...")
        return _primer_from_dataset(args, idx=int(primer_file))

    print("Primer file is a MIDI path. Loading and tokenizing...")
    raw_mid = encode_midi(primer_file)
    if len(raw_mid) == 0:
        print("Error: No midi messages in primer file:", primer_file)
        return None

    primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
    primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())

    print("Using primer file:", primer_file)
    return primer


# main
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Generates music from a model specified by command line arguments

    Writes primer.mid (unless generating from silence) and either beam.mid or
    rand.mid into args.output_dir. The generated file contains only the newly
    generated tokens; the primer prefix is sliced off.
    ----------
    """

    args = parse_generate_args()
    print_generate_args(args)

    if args.force_cpu:
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    primer = _load_primer(args)
    if primer is None:
        return

    # Actual primer length; used below to strip the primer from the output
    num_primer = primer.shape[0]

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())

    # map_location lets weights saved on a GPU load when running on CPU
    model.load_state_dict(torch.load(args.model_weights, map_location=get_device()))

    # Don't save a primer if we started from silence. The check must tolerate
    # primer_file being None (random primer), which previously crashed here.
    if not _is_silence(args.primer_file):
        f_path = os.path.join(args.output_dir, "primer.mid")
        decode_midi(primer.cpu().numpy(), file_path=f_path)

    # GENERATION
    model.eval()
    with torch.no_grad():
        if args.beam > 0:
            print("BEAM:", args.beam)
            beam = args.beam
            out_name = "beam.mid"
        else:
            print("RAND DIST")
            beam = 0
            out_name = "rand.mid"

        seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=beam)

        # Slice the primer: the output is assumed to begin with the primer tokens
        generated_only = seq[0][num_primer:]

        f_path = os.path.join(args.output_dir, out_name)
        decode_midi(generated_only.cpu().numpy(), file_path=f_path)


if __name__ == "__main__":
    main()