Spaces:
Configuration error
Configuration error
| import os | |
| import json | |
| import torch | |
| import numpy as np | |
| import audioldm.hifigan as hifigan | |
# Hyper-parameters for the HiFi-GAN generator built by `get_vocoder`
# (16 kHz sampling rate, 64 mel bins, as the constant's name indicates).
HIFIGAN_16K_64 = {
    "resblock": "1",
    # Optimisation / run settings; nothing in the visible part of this file
    # reads them.
    "num_gpus": 6,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,
    # Generator layout. The upsample rates multiply to 160, the same value
    # as "hop_size" below.
    "upsample_rates": [5, 4, 2, 2, 2],
    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
    "upsample_initial_channel": 1024,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [
        [1, 3, 5],
        [1, 3, 5],
        [1, 3, 5],
    ],
    "segment_size": 8192,
    # Spectrogram / audio front-end settings.
    "num_mels": 64,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 160,
    "win_size": 1024,
    "sampling_rate": 16000,
    "fmin": 0,
    "fmax": 8000,
    "fmax_for_loss": None,
    "num_workers": 4,
    # Distributed-run settings.
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1,
    },
}
def get_available_checkpoint_keys(model, ckpt):
    """Return the checkpoint entries that are compatible with ``model``.

    The file at ``ckpt`` is loaded and its ``"state_dict"`` entry is compared
    against ``model.state_dict()``. An entry is kept only when its key exists
    in the model and both tensors report the same ``size()``; every other key
    is reported on stdout and skipped.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        ckpt: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        dict: The matching subset of the checkpoint's state dict. Tensors are
        mapped to CPU memory while loading.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    print("==> Attemp to reload from %s" % ckpt)
    # NOTE(security): torch.load deserialises with pickle; only point this at
    # checkpoint files from a trusted source.
    # map_location="cpu" lets a checkpoint that was saved from a GPU be read
    # on a machine without one.
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
    current_state_dict = model.state_dict()
    new_state_dict = {}
    for key, tensor in state_dict.items():
        if (
            key in current_state_dict
            and current_state_dict[key].size() == tensor.size()
        ):
            new_state_dict[key] = tensor
        else:
            print("==> WARNING: Skipping %s" % key)
    print(
        "%s out of %s keys are matched"
        % (len(new_state_dict), len(state_dict))
    )
    return new_state_dict
def get_param_num(model):
    """Count the scalar elements across everything ``model.parameters()`` yields."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def get_vocoder(config, device):
    """Build a HiFi-GAN generator in eval mode and move it to ``device``.

    Args:
        config: Accepted for interface compatibility but not used; the
            generator is always built from the module-level
            ``HIFIGAN_16K_64`` settings.
        device: Target passed to the generator's ``to()`` call.

    Returns:
        The ``hifigan.Generator`` instance, with weight norm removed.
        No checkpoint is loaded here.
    """
    # The caller-supplied `config` is deliberately ignored (see docstring).
    hparams = hifigan.AttrDict(HIFIGAN_16K_64)
    generator = hifigan.Generator(hparams)
    generator.eval()
    generator.remove_weight_norm()
    generator.to(device)
    return generator
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` and return the result as int16 samples.

    Args:
        mels: Input forwarded unchanged to ``vocoder``.
        vocoder: Callable model; it is switched to eval mode and invoked
            under ``torch.no_grad()``.
        lengths: Optional upper bound on the number of samples kept along
            the second axis of the output. ``None`` keeps everything.

    Returns:
        numpy.ndarray: int16 array. Assumes the vocoder output is
        ``(batch, 1, samples)`` with values nominally in [-1, 1], so the
        result is ``(batch, samples)`` - TODO confirm against the generator.
    """
    vocoder.eval()
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    # Scale to the 16-bit range. A sample equal to +1.0 becomes 32768, which
    # does not fit in int16 and would wrap to a large negative value, so the
    # scaled signal is clamped before the cast.
    scaled = wavs.cpu().numpy() * 32768
    wavs = np.clip(scaled, -32768, 32767).astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs