Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from deepfillv2 import network | |
| import skimage | |
| from config import GPU_DEVICE | |
| # ---------------------------------------- | |
| # Network | |
| # ---------------------------------------- | |
def create_generator(opt):
    """Build the gated generator and run weight initialisation on it.

    Args:
        opt: Options object. It is passed whole to ``network.GatedGenerator``;
            its ``init_type`` and ``init_gain`` attributes are forwarded to
            ``network.weights_init``.

    Returns:
        The generator instance after ``network.weights_init`` has been
        applied to it.
    """
    net = network.GatedGenerator(opt)
    print("-- Generator is created! --")
    init_kwargs = {"init_type": opt.init_type, "init_gain": opt.init_gain}
    network.weights_init(net, **init_kwargs)
    print("-- Initialized generator with %s type --" % opt.init_type)
    return net
def create_discriminator(opt):
    """Build the patch discriminator and run weight initialisation on it.

    Args:
        opt: Options object. It is passed whole to
            ``network.PatchDiscriminator``; its ``init_type`` and
            ``init_gain`` attributes are forwarded to
            ``network.weights_init``.

    Returns:
        The discriminator instance after ``network.weights_init`` has been
        applied to it.
    """
    net = network.PatchDiscriminator(opt)
    print("-- Discriminator is created! --")
    init_type, init_gain = opt.init_type, opt.init_gain
    network.weights_init(net, init_type=init_type, init_gain=init_gain)
    print("-- Initialize discriminator with %s type --" % opt.init_type)
    return net
def create_perceptualnet():
    """Instantiate ``network.PerceptualNet`` and return it.

    Per the original author's note this network is the first 15 layers of
    VGG16 (up to conv3_3); the definition lives in ``network`` and is not
    visible here -- verify there.
    """
    net = network.PerceptualNet()
    print("-- Perceptual network is created! --")
    return net
| # ---------------------------------------- | |
| # PATH processing | |
| # ---------------------------------------- | |
def text_readlines(filename):
    """Read a text file and return its lines without line terminators.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list with one string per line. An empty list is returned when the
        file cannot be opened.
    """
    try:
        handle = open(filename, "r")
    except IOError:
        return []
    # The context manager closes the file even if reading raises.
    with handle:
        # Strip only a real newline: the last line of a file frequently has
        # none, and dropping its final character would corrupt the entry.
        return [
            line[:-1] if line.endswith("\n") else line for line in handle
        ]
def savetxt(name, loss_log):
    """Convert ``loss_log`` to a NumPy array and write it with ``np.savetxt``.

    Args:
        name: Destination passed to ``np.savetxt``.
        loss_log: Values to store; converted with ``np.array`` first.
    """
    np.savetxt(name, np.array(loss_log))
def get_files(path, mask=False):
    """Recursively collect the path of every file found under ``path``.

    macOS ``.DS_Store`` metadata files are skipped.

    Args:
        path: Root directory to walk.
        mask: Not used by this function; kept so existing callers that pass
            it keep working.

    Returns:
        A list of file paths, each joined with the directory it was found
        in, in ``os.walk`` order. Empty if ``path`` contains no files.
    """
    ret = []
    for root, _dirs, filenames in os.walk(path):
        for filename in filenames:
            # Skip only the Finder metadata file. The previous test was
            # inverted and discarded every real file instead.
            if filename == ".DS_Store":
                continue
            ret.append(os.path.join(root, filename))
    return ret
def get_names(path):
    """Recursively collect the bare name of every file found under ``path``.

    Args:
        path: Root directory to walk.

    Returns:
        A list of file names (without their directory), in ``os.walk``
        order.
    """
    names = []
    for _root, _dirs, filenames in os.walk(path):
        names.extend(filenames)
    return names
def text_save(content, filename, mode="a"):
    """Write each item of ``content`` to ``filename``, one per line.

    Args:
        content: Items to write; each is converted with ``str``.
        filename: Path of the text file to write.
        mode: Mode passed to ``open``; the default ``"a"`` appends.
    """
    # The context manager closes the file even if a conversion or write
    # raises part-way through.
    with open(filename, mode) as handle:
        handle.writelines(str(item) + "\n" for item in content)
def check_path(path):
    """Create directory ``path`` (including parents) if nothing exists there.

    Args:
        path: Directory path to ensure. Nothing is done when the path
            already exists.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)
| # ---------------------------------------- | |
| # Validation and Sample at training | |
| # ---------------------------------------- | |
def save_sample_png(
    sample_folder, sample_name, img_list, name_list, pixel_max_cnt=255
):
    """Convert each tensor in ``img_list`` to an 8-bit image and write it.

    Args:
        sample_folder: Directory the image is written into.
        sample_name: File name joined onto ``sample_folder``.
        img_list: Tensors indexed as 4-D; presumably
            (batch, channel, height, width) RGB in [0, 1] -- TODO confirm.
            Only the first item along the leading dimension is saved.
        name_list: Accepted but not used by this function.
        pixel_max_cnt: Upper bound used when clipping pixel values.

    NOTE(review): every entry is written to the same
    ``sample_folder/sample_name`` path, so when ``img_list`` holds several
    images only the last one remains on disk.
    """
    for tensor in img_list:
        # Recover normalization: * 255 because last layer is sigmoid activated
        scaled = tensor * 255
        # Work on a copy so the caller's tensor data is left untouched.
        channels_last = scaled.clone().data.permute(0, 2, 3, 1)
        first_image = channels_last[0, :, :, :]
        pixels = first_image.to("cpu").numpy()
        pixels = np.clip(pixels, 0, pixel_max_cnt).astype(np.uint8)
        bgr = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        # Save to certain path
        out_path = os.path.join(sample_folder, sample_name)
        cv2.imwrite(out_path, bgr)
def psnr(pred, target, pixel_max_cnt=255):
    """Return the peak signal-to-noise ratio between two tensors, in dB.

    Args:
        pred: Predicted tensor.
        target: Reference tensor, broadcast-compatible with ``pred``.
        pixel_max_cnt: Maximum possible pixel value (the peak signal).

    Returns:
        ``20 * log10(pixel_max_cnt / rmse)``, where ``rmse`` is the root of
        the mean squared difference over all elements. ``inf`` when the
        tensors are identical.
    """
    diff = target - pred
    rmse = torch.mean(torch.mul(diff, diff)).item() ** 0.5
    if rmse == 0:
        # Identical inputs: dividing by the Python float 0.0 would raise
        # ZeroDivisionError, so report an unbounded PSNR instead.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt / rmse)
def grey_psnr(pred, target, pixel_max_cnt=255):
    """Return the PSNR of two tensors after summing each over dimension 0.

    The peak value is taken as ``pixel_max_cnt * 3``; this presumably
    assumes dimension 0 holds three colour channels -- TODO confirm against
    callers.

    Args:
        pred: Predicted tensor.
        target: Reference tensor with the same layout as ``pred``.
        pixel_max_cnt: Maximum possible value of a single channel.

    Returns:
        ``20 * log10(pixel_max_cnt * 3 / rmse)`` computed on the summed
        tensors. ``inf`` when the summed tensors are identical.
    """
    diff = torch.sum(target, dim=0) - torch.sum(pred, dim=0)
    rmse = torch.mean(torch.mul(diff, diff)).item() ** 0.5
    if rmse == 0:
        # Identical inputs: dividing by the Python float 0.0 would raise
        # ZeroDivisionError, so report an unbounded PSNR instead.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt * 3 / rmse)
def ssim(pred, target):
    """Return the structural similarity of the first images of two batches.

    Args:
        pred: Predicted tensor, indexed as 4-D; presumably
            (batch, channel, height, width) -- TODO confirm.
        target: Reference tensor with the same layout as ``pred``.

    Returns:
        The value computed by ``skimage.measure.compare_ssim`` on the first
        item of each tensor, rearranged so channels come last.
    """
    # NumPy can only wrap host memory, so the copies must be moved to the
    # CPU before .numpy(); moving them to a CUDA device makes .numpy() raise.
    pred_img = pred.clone().data.permute(0, 2, 3, 1).to("cpu").numpy()[0]
    target_img = target.clone().data.permute(0, 2, 3, 1).to("cpu").numpy()[0]
    # NOTE(review): newer scikit-image releases replace compare_ssim with
    # skimage.metrics.structural_similarity -- confirm the pinned version.
    return skimage.measure.compare_ssim(
        target_img, pred_img, multichannel=True
    )