Spaces:
Running
Running
| import re | |
| import os | |
| import time | |
| import torch | |
| import random | |
| import shutil | |
| import argparse | |
| import soundfile as sf | |
| from transformers import GPT2Config | |
| from model import Patchilizer, TunesFormer | |
| from convert import abc2xml, xml2img, xml2, transpose_octaves_abc | |
| from utils import ( | |
| PATCH_NUM_LAYERS, | |
| PATCH_LENGTH, | |
| CHAR_NUM_LAYERS, | |
| PATCH_SIZE, | |
| SHARE_WEIGHTS, | |
| TEMP_DIR, | |
| DEVICE, | |
| ) | |
def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``type=bool`` cannot be used with argparse for this purpose: it calls
    ``bool()`` on the raw string, so every non-empty value (including
    ``"False"``) would be parsed as ``True``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, as it was under the old ``bool()`` parsing.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(parser: argparse.ArgumentParser):
    """Register the generation options on *parser* and parse the command line.

    Args:
        parser: the argument parser to extend; it is mutated in place.

    Returns:
        The namespace produced by ``parser.parse_args()`` (reads ``sys.argv``).
    """
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="integer to define the maximum length in tokens of each tune",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="float to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="integer to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
    # Boolean flags go through _str2bool so that "-show_control_code False"
    # really yields False instead of bool("False") == True.
    parser.add_argument(
        "-show_control_code",
        type=_str2bool,
        default=False,
        help="whether to show control code",
    )
    parser.add_argument(
        "-template",
        type=_str2bool,
        default=True,
        help="whether to generate by template",
    )
    return parser.parse_args()
def get_abc_key_val(text: str, key="K"):
    """Return the value of the first ``<key>:`` field found in *text*.

    The value is everything between the colon and the next newline, with
    surrounding whitespace stripped. ``None`` is returned when no
    newline-terminated ``<key>:`` occurrence exists.
    """
    field_re = re.compile(re.escape(key) + r":(.*?)\n")
    hit = field_re.search(text)
    return hit.group(1).strip() if hit else None
def adjust_volume(in_audio: str, dB_change: int):
    """Rewrite the audio file at *in_audio* with its samples scaled by *dB_change* dB.

    The file is read with ``soundfile`` and written back to the same path
    at the same sample rate.
    """
    samples, sample_rate = sf.read(in_audio)
    # Amplitude ratio for a change expressed in decibels: 10^(dB / 20).
    gain = 10 ** (dB_change / 20)
    sf.write(in_audio, samples * gain, sample_rate)
def clean_dir(dir_path: str):
    """Leave *dir_path* as an empty directory.

    Any existing tree at that path is deleted first, then the directory
    (and missing parents) is created again.
    """
    already_present = os.path.exists(dir_path)
    if already_present:
        shutil.rmtree(dir_path)

    os.makedirs(dir_path)
def generate_music(
    args,
    emo: str,
    weights: str,
    outdir=f"{TEMP_DIR}/output",
    fix_tempo=None,
    fix_pitch=None,
    fix_volume=None,
):
    """Generate ABC tunes for an emotion label and render them to score/audio files.

    The checkpoint at *weights* is loaded into a TunesFormer model, which is
    prompted with ``A:{emo}`` and sampled patch by patch. The resulting ABC
    text then gets its tempo, key and mode adjusted, is converted to
    MusicXML and finally exported through ``xml2`` / ``xml2img``.

    Args:
        args: namespace providing ``num_tunes``, ``max_patch``, ``top_p``,
            ``top_k``, ``temperature``, ``seed``, ``show_control_code`` and
            ``template``.
        emo: emotion label used as the prompt. ``"Q1"``..``"Q4"`` select
            dedicated tempo ranges; ``"Q1"``/``"Q4"`` are treated as major,
            everything else as minor.
        weights: checkpoint path for ``torch.load``; the loaded object must
            have a ``"model"`` entry holding the state dict.
        outdir: output directory; it is wiped and recreated on every call.
        fix_tempo: when not ``None``, used verbatim as the ``Q:`` value
            instead of a random emotion-dependent tempo.
        fix_pitch: when not ``None``, the offset handed to
            ``transpose_octaves_abc`` (presumably semitones, given the
            -12/-24 defaults; confirm against ``convert``). ``0`` means no
            transposition.
        fix_volume: when not ``None``, the dB change applied to the rendered
            audio (skipped when ``0``) instead of the emotion default.

    Returns:
        A tuple ``(audio, midi, pdf, xml, mxl, tunes, jpg)`` where ``tunes``
        is the post-processed ABC text and the other items are whatever
        ``xml2`` / ``xml2img`` / the XML conversion return (presumably file
        paths).
    """
    clean_dir(outdir)
    patchilizer = Patchilizer()
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    checkpoint = torch.load(weights, map_location=DEVICE)
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()

    prompt = f"A:{emo}\n"
    tunes = ""
    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code
    fname_prefix = emo if args.template else "Melody"
    # Lines/bars starting with these prefixes are hidden unless
    # show_control_code is set.
    control_prefixes = ("S:", "B:", "E:", "D:")
    # Loop-invariant headers; defined up front so that `title` is also
    # available to the post-processing below when num_tunes is 0.
    title = f"T:{fname_prefix} Fragment\n"
    artist = "C:Generated by AI\n"

    print(" Hyper parms ".center(60, "#"), "\n")
    for name, value in vars(args).items():
        print(f"{name}: {value}")

    print("\n", " Output tunes ".center(60, "#"))
    start_time = time.time()
    for i in range(num_tunes):
        header = f"X:{i + 1}\n{title}{artist}{prompt}"
        tune = ""
        skip = False
        # re.split with a capturing group keeps the "\n" separators as
        # items; `skip` drops the separator that follows a hidden line.
        for line in re.split(r"(\n)", header):
            if show_control_code or line[:2] not in control_prefixes:
                if not skip:
                    print(line, end="")
                    tune += line

                skip = False

            else:
                skip = True

        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
            device=DEVICE,
        )
        # Initialised so the generation loop never sees an unbound name.
        remaining_tokens = ""
        if tune == "":
            tokens = None

        else:
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix) :]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=DEVICE,
            )

        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed,
            )
            tokens = None
            if predicted_patch[0] == patchilizer.eos_token_id:
                break

            next_bar = patchilizer.decode([predicted_patch])
            if show_control_code or next_bar[:2] not in control_prefixes:
                print(next_bar, end="")
                tune += next_bar

            if next_bar == "":
                break

            # The leftover prompt characters belong to the first generated bar.
            next_bar = remaining_tokens + next_bar
            remaining_tokens = ""
            predicted_patch = torch.tensor(
                patchilizer.bar2patch(next_bar),
                device=DEVICE,
            ).unsqueeze(0)
            input_patches = torch.cat(
                [input_patches, predicted_patch.unsqueeze(0)],
                dim=1,
            )

        tunes += f"{tune}\n\n"
        print("\n")

    # fix tempo
    if fix_tempo is not None:
        tempo = f"Q:{fix_tempo}\n"

    else:
        tempo = f"Q:{random.randint(88, 132)}\n"
        if emo == "Q1":
            tempo = f"Q:{random.randint(160, 184)}\n"
        elif emo == "Q2":
            tempo = f"Q:{random.randint(184, 228)}\n"
        elif emo == "Q3":
            tempo = f"Q:{random.randint(40, 69)}\n"
        elif emo == "Q4":
            tempo = f"Q:{random.randint(40, 69)}\n"

    # Drop any tempo field the model produced; the chosen tempo replaces
    # the "A:" prompt line below.
    Q_val = get_abc_key_val(tunes, "Q")
    if Q_val:
        tunes = tunes.replace(f"Q:{Q_val}\n", "")

    K_val = get_abc_key_val(tunes)
    if K_val == "none":
        K_val = "C"
        tunes = tunes.replace("K:none\n", f"K:{K_val}\n")

    tunes = tunes.replace(f"A:{emo}\n", tempo)
    mode = "major" if emo == "Q1" or emo == "Q4" else "minor"  # fix mode:major/minor
    # K_val is None when no "K:" field was generated; nothing to rewrite then.
    if K_val is not None:
        if mode == "major" and "m" in K_val:
            tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")
        elif mode == "minor" and "m" not in K_val:
            tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.replace('dor', '')}min\n")

    print("Generation time: {:.2f} seconds".format(time.time() - start_time))
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    try:
        final_xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
        # fix avg_pitch (octave): an explicit fix_pitch wins, otherwise
        # minor-mode emotions are shifted down (twice as far for Q2).
        if fix_pitch is not None:
            offset = fix_pitch
        elif mode == "minor":
            offset = -24 if emo == "Q2" else -12
        else:
            offset = 0

        if offset:
            tunes, xml = transpose_octaves_abc(
                tunes,
                f"{outdir}/{timestamp}.musicxml",
                offset,
            )
            tunes = tunes.replace(title + title, title)
            os.rename(xml, final_xml)
            xml = final_xml

        else:
            # No transposition requested (this includes fix_pitch == 0).
            xml = abc2xml(tunes, final_xml)

        audio = xml2(xml, "wav")
        if fix_volume is not None:
            if fix_volume:
                adjust_volume(audio, fix_volume)

        elif os.path.exists(audio):
            if emo == "Q1":
                adjust_volume(audio, 5)
            elif emo == "Q2":
                adjust_volume(audio, 10)

        mxl = xml2(xml, "mxl")
        midi = xml2(xml, "mid")
        pdf, jpg = xml2img(xml)
        return audio, midi, pdf, xml, mxl, tunes, jpg

    except Exception as e:
        # Best-effort: any failure in conversion/rendering triggers a full
        # regeneration. The caller's settings are forwarded so the retry
        # honours them.
        # NOTE(review): the retry is unbounded; a deterministic failure
        # recurses until RecursionError.
        print(f"{e}")
        return generate_music(
            args,
            emo,
            weights,
            outdir=outdir,
            fix_tempo=fix_tempo,
            fix_pitch=fix_pitch,
            fix_volume=fix_volume,
        )