Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import torchvision | |
| from torchvision import transforms | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from models.modelNetA import Generator as GA | |
| from models.modelNetB import Generator as GB | |
| from models.modelNetC import Generator as GC | |
scale_size = 128  # base input resolution (pixels per side)
scale_sizes = [128, 256, 512]  # input resolutions offered in the UI

# Checkpoint for each selectable model. Dict order defines MODELS_TYPE, and
# `generators` below is built in that same order, so a generator can be looked
# up with MODELS_TYPE.index(model_name).
modeltype2path = {
    'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
    'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
    'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
}
DEVICE = 'cpu'
MODELS_TYPE = list(modeltype2path.keys())

# Generator class for each model name. Keyed explicitly so the pairing does not
# silently depend on a parallel list staying in the same order as the dict.
_MODEL_CLASSES = {
    'ModelA': GA,
    'ModelB': GB,
    'ModelC': GC,
}


def _load_generator(generator_cls, checkpoint_path):
    """Instantiate a generator, load its checkpoint, return it in eval mode on DEVICE.

    The state dict is loaded into a ``torch.nn.DataParallel`` wrapper and the
    inner module is returned. Presumably the checkpoint was saved from a
    DataParallel-wrapped model (keys prefixed with ``module.``) — TODO confirm.
    """
    wrapped = torch.nn.DataParallel(generator_cls())
    # torch.load unpickles the file: only point it at trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(DEVICE))
    wrapped.load_state_dict(state_dict)
    generator = wrapped.module.to(DEVICE)
    generator.eval()
    return generator


# load model
generators = [
    _load_generator(_MODEL_CLASSES[name], modeltype2path[name])
    for name in MODELS_TYPE
]

# Inference preprocessing: PIL image -> single-channel float tensor
# (ToTensor maps uint8 pixels into [0, 1]).
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])
def _sr_tensor_to_array(tensor):
    """Convert an image tensor with values in [0, 1] to an H x W x 3 uint8 array.

    Performs the same conversion ``torchvision.utils.save_image`` applies before
    writing a PNG (make_grid, scale to 0-255 with rounding, clamp), but in
    memory. This avoids a temporary file shared by concurrent requests.
    """
    grid = torchvision.utils.make_grid(tensor)
    return (
        grid.mul(255).add_(0.5).clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )


def _render_dem(dem):
    """Render a 2-D array as a 'jet' heatmap with a colorbar.

    Returns the rendered figure as an H x W x 3 uint8 array.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(dem, cmap='jet', vmin=0, vmax=np.max(dem))
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and removed in 3.10. Drop alpha to keep an RGB image.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request would leak one.
        plt.close(fig)


def predict(input_image, model_name, input_scale_factor):
    """Run the selected generator on an uploaded image.

    Args:
        input_image: Image array from the ``gr.Image()`` input. Assumed to be
            uint8-compatible (H x W, H x W x 1, x 3 or x 4) — TODO confirm.
            It is converted to RGB, then to grayscale by ``preprocess``.
        model_name: Entry of ``MODELS_TYPE`` selecting the generator.
        input_scale_factor: Side length in pixels of the square the input is
            resized to before inference. Falls back to ``scale_size`` if None.

    Returns:
        ``(info, sr, data)``:
            - info: text naming the model and its parameter count.
            - sr: the super-resolution output as an H x W x 3 uint8 array.
            - data: a rendered heatmap (with colorbar) of the DTM output as an
              H x W x 3 uint8 array.

    Raises:
        gr.Error: if no image was provided or the model name is unknown.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if model_name not in MODELS_TYPE:
        raise gr.Error(f"Unknown model: {model_name!r}")
    if input_scale_factor is None:
        input_scale_factor = scale_size
    size = int(input_scale_factor)
    generator = generators[MODELS_TYPE.index(model_name)]

    # Let PIL infer the mode and convert: forcing 'RGB' misreads anything that
    # is not H x W x 3. A trailing singleton channel is dropped because PIL
    # cannot build an image from an H x W x 1 array.
    arr = np.asarray(input_image).astype('uint8')
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[..., 0]
    pil_image = Image.fromarray(arr).convert('RGB')
    pil_image = transforms.Resize((size, size))(pil_image)

    # transform image to torch and do preprocessing
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    # Min-max normalise to [0, 1]. A constant image would give 0/0 = NaN, so
    # map it to all zeros instead.
    img_min, img_max = torch.min(torch_img), torch.max(torch_img)
    value_range = img_max - img_min
    if value_range > 0:
        torch_img = (torch_img - img_min) / value_range
    else:
        torch_img = torch.zeros_like(torch_img)

    # model predict
    with torch.no_grad():
        output = generator(torch_img)
    sr, sr_dem_selected = output[0], output[1]

    # transform torch to image
    sr = _sr_tensor_to_array(sr.squeeze(0).cpu())
    dem = sr_dem_selected.squeeze().cpu().detach().numpy()
    data = _render_dem(dem)

    n_params = sum(p.numel() for p in generator.parameters())
    info = f"{model_name} with {n_params} parameters"
    return info, sr, data
# Offer every file in demo_imgs/ as an example, run with the first model at the
# base input size. The list is sorted so the order is stable across restarts.
# Hidden files (e.g. .gitkeep) are skipped, and a missing folder just means
# no examples rather than a crash at startup.
_EXAMPLES_DIR = 'demo_imgs'
_examples = (
    [
        [os.path.join(_EXAMPLES_DIR, name), MODELS_TYPE[0], scale_size]
        for name in sorted(os.listdir(_EXAMPLES_DIR))
        if not name.startswith('.')
    ]
    if os.path.isdir(_EXAMPLES_DIR)
    else []
)

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4; the
        # top-level components replace it. The defaults keep predict() from
        # receiving None when the user leaves a choice unselected.
        gr.Radio(choices=MODELS_TYPE, value=MODELS_TYPE[0], label='Model'),
        gr.Radio(choices=scale_sizes, value=scale_size, label='Input size'),
    ],
    outputs=[
        gr.Text(label='Model info'),
        gr.Image(label='Super Resolution'),
        gr.Image(label='DTM'),
    ],
    examples=_examples or None,
    title="Super Resolution and DTM Estimation",
    description="This demo predicts Super Resolution and (Super Resolution) DTM from a Grayscale image (if RGB we convert it).",
)
iface.launch()