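"""
Entry point script: generates music from a trained MusicTransformer.

The primer can be:
  * omitted (None)       -> a random piece drawn from the e_piano dataset
  * the string "silence" -> a single medium-velocity token
  * an integer index     -> that piece from the e_piano dataset
  * a MIDI file path     -> tokenized with processor.encode_midi

Outputs (beam.mid or rand.mid, plus primer.mid unless generating from
silence) are written to args.output_dir. See utilities/argument_funcs.py
for the command line argument definitions.
"""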
import os
import random

import torch

from processor import encode_midi, decode_midi

from utilities.argument_funcs import parse_generate_args, print_generate_args
from model.music_transformer import MusicTransformer
from dataset.e_piano import create_epiano_datasets, process_midi

from utilities.constants import *
from utilities.device import get_device, use_cuda


# main
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Generates music from a model specified by command line arguments
    ----------
    """

    args = parse_generate_args()
    print_generate_args(args)

    if (args.force_cpu):
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    # --- MODIFIED LOGIC ---
    # args.primer_file can be None, the string "silence", an integer index
    # into the dataset, or a MIDI file path
    if (args.primer_file is None):
        # --- Load dataset ONLY if no primer file is given ---
        print("No primer file provided, loading dataset to pick a random primer...")
        _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
        # --- --- --- --- --- --- --- --- --- --- --- --- --- ---

        idx = random.randrange(len(dataset))
        primer, _ = dataset[idx]
        primer = primer.to(get_device())
        num_primer = primer.shape[0]
        print("Using primer index:", idx, "(", dataset.data_files[idx], ")")

    else:
        # --- Primer file is provided, NO DATASET NEEDED (unless it's an index) ---
        f = args.primer_file

        # --- NEW: Check for "silence" ---
        if (f.lower() == "silence"):
            print("Generating from silence...")
            # Create a primer with one token: a medium velocity (64 // 4 = 16)
            # Velocity START_IDX = 356. Token = 356 + 16 = 372
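            # (Assumed vocabulary layout from processor.py: note-on 0-127,
            # note-off 128-255, time-shift 256-355, velocity 356-387.)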
            primer = torch.tensor([372], dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]  # This will be 1

        # This branch handles a primer given as an integer index (e.g., "3")
        elif (f.isdigit()):
            print("Primer file is an index, loading dataset...")
            # --- Load dataset ONLY if primer is an index ---
            _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
            # --- --- --- --- --- --- --- --- --- --- --- ---
            idx = int(f)
            primer, _ = dataset[idx]
            primer = primer.to(get_device())
            num_primer = primer.shape[0]
            print("Using primer index:", idx, "(", dataset.data_files[idx], ")")

        # This branch handles a primer given as a MIDI file path (e.g., "my_primer.mid")
        else:
            print("Primer file is a MIDI path. Loading and tokenizing...")
            raw_mid = encode_midi(f)
            if (len(raw_mid) == 0):
                print("Error: No midi messages in primer file:", f)
                return

            primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
            primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
            num_primer = primer.shape[0]  # Get the actual primer length
            print("Using primer file:", f)

    # --- END MODIFIED LOGIC ---

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr).to(get_device())
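    # NOTE: these hyperparameters must match those used to train the
    # checkpoint, otherwise load_state_dict below fails on shape mismatches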

    # map_location lets GPU-saved weights load under --force_cpu
    model.load_state_dict(torch.load(args.model_weights, map_location=get_device()))

    # --- MODIFICATION: Don't save a primer if we started from silence ---
    if (args.primer_file is None or args.primer_file.lower() != "silence"):
        f_path = os.path.join(args.output_dir, "primer.mid")
        decode_midi(primer.cpu().numpy(), file_path=f_path)
    # --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

    # GENERATION: beam search when args.beam > 0, otherwise sampling from
    # the model's output distribution
    model.eval()
    with torch.set_grad_enabled(False):
        if (args.beam > 0):
            print("BEAM:", args.beam)

            beam_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=args.beam)

            # --- MODIFICATION: Slice off the primer so only newly generated tokens are saved ---
            generated_only = beam_seq[0][num_primer:]
            # ------------------------------------

            f_path = os.path.join(args.output_dir, "beam.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)
        else:
            print("RAND DIST")

            rand_seq = model.generate(primer.unsqueeze(0), args.target_seq_length, beam=0)

            # --- MODIFICATION: Slice off the primer so only newly generated tokens are saved ---
            generated_only = rand_seq[0][num_primer:]
            # ------------------------------------

            f_path = os.path.join(args.output_dir, "rand.mid")
            decode_midi(generated_only.cpu().numpy(), file_path=f_path)


if __name__ == "__main__":
    main()
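
# Example invocation (a sketch: the exact flag spellings are defined in
# utilities/argument_funcs.py and may differ from the attribute names
# used above):
#
#   python generate.py -output_dir ./output -model_weights ./weights.pickle \
#       -primer_file silence -target_seq_length 1024 --rpr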