Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision | |
| from torchvision import transforms | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from models.modelNetA import Generator as GA | |
| from models.modelNetB import Generator as GB | |
| from models.modelNetC import Generator as GC | |
# Use the GPU when one is present. Hosts without CUDA (e.g. a CPU-only Space)
# would otherwise fail on the first `.to('cuda')`.
# Set DEVICE = 'cpu' explicitly to force CPU inference.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Which generator variant to run; must be a key of both dicts below.
model_type = 'model_c'

# Checkpoint file for each generator variant.
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}

# Generator class for each variant.
modeltype2generator = {
    'model_a': GA,
    'model_b': GB,
    'model_c': GC,
}

# Fail loudly on a typo instead of hitting a NameError on `generator` later.
if model_type not in modeltype2generator:
    raise ValueError(
        f"Unknown model_type {model_type!r}; "
        f"expected one of {sorted(modeltype2generator)}"
    )

generator = modeltype2generator[model_type]()

# Wrapping in DataParallel prefixes parameter names with `module.`. Loading
# relies on this, which suggests the checkpoint was saved from a
# DataParallel-wrapped model (TODO confirm). On a host without CUDA the
# wrapper simply holds the module.
generator = torch.nn.DataParallel(generator)

# Load onto the CPU first so the checkpoint can be read even without a GPU.
state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)

# Unwrap back to the bare module and move it to the chosen device.
generator = generator.module.to(DEVICE)
generator.eval()
# Convert the input to a single-channel float tensor (C, H, W).
# A Resize((128, 128)) step was disabled in the original, so the image is fed
# at its native resolution.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

# Open the image in a context manager so the file handle is released once
# its pixels have been converted to a tensor.
with Image.open('demo_imgs/fake.jpg') as input_img:
    # Add a batch dimension -> (1, C, H, W), then move to the model's device.
    torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)

# Min-max normalise to [0, 1]. For a constant image max == min; dividing by
# that zero range would fill the tensor with NaNs, so only shift it instead.
img_min = torch.min(torch_img)
img_range = torch.max(torch_img) - img_min
if img_range > 0:
    torch_img = (torch_img - img_min) / img_range
else:
    torch_img = torch_img - img_min

with torch.no_grad():
    output = generator(torch_img)

# The generator returns (at least) two outputs. Per the names, presumably a
# super-resolved image and a selected DEM prediction (TODO confirm).
# Both still carry the batch dimension here.
sr, sr_dem_selected = output[0], output[1]
# Drop the batch dimension and write the first output as an image file.
sr = sr.squeeze(0).cpu()
print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')

# Turn the second output into a NumPy array (all size-1 dims removed) and
# render it as a colour-mapped plot.
sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
print(sr_dem_selected.shape)

# The colour scale is anchored at 0. Clamp vmax so an output with no positive
# values cannot give vmax < vmin, which matplotlib rejects with ValueError.
dem_vmax = max(float(np.max(sr_dem_selected)), 0.0)
plt.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=dem_vmax)
plt.colorbar()
plt.savefig('test.png')
# Release the figure now that it has been written to disk.
plt.close()