# IC9600 image-complexity inference script (originally hosted on a Spaces demo, run on a T4 GPU).
import argparse
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# Make repo-local packages importable when this file is run as a script.
# Must happen BEFORE the project-local imports below.
root_path = os.path.abspath('.')
sys.path.append(root_path)

from opt import opt
from dataset_curation_pipeline.IC9600.ICNet import ICNet

# 512x512 ImageNet-normalized preprocessing expected by the ICNet checkpoint.
inference_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
def blend(ori_img, ic_img, alpha = 0.8, cm = plt.get_cmap("magma")):
    """Overlay a complexity map on the original image as a colored heat-map.

    Args:
        ori_img: H x W x 3 uint8 RGB array of the original image.
        ic_img: H x W complexity map in [0, 1], fed to the colormap.
        alpha: blend weight of the heat-map (0 = original only).
        cm: matplotlib colormap applied to ic_img.

    Returns:
        H x W x 3 uint8 array; channels are reversed (BGR) so the result
        can be written directly with cv2.imwrite.
    """
    rgba = cm(ic_img)  # H x W x 4 RGBA floats in [0, 1]
    # Drop the alpha channel and reverse RGB -> BGR in one slice.
    bgr_bytes = (rgba[:, :, -2::-1] * 255).astype(np.uint8)
    heatmap = Image.fromarray(bgr_bytes)
    base = Image.fromarray(ori_img)
    blended = Image.blend(base, heatmap, alpha=alpha)
    return np.array(blended)
def infer_one_image(model, img_path):
    """Compute the IC9600 complexity score for a single image.

    Args:
        model: ICNet instance in eval mode; the input tensor is moved to
            whatever device the model's parameters live on.
        img_path: path to an image file readable by PIL.

    Returns:
        float: scalar complexity score produced by the model.
    """
    with torch.no_grad():
        ori_img = Image.open(img_path).convert("RGB")
        img = inference_transform(ori_img).unsqueeze(0)
        # Bug fix: the original hard-coded img.cuda(), which crashes on
        # CPU-only hosts and ignores the device the model was moved to.
        # Follow the model's own device instead.
        device = next(model.parameters()).device
        img = img.to(device)
        # The model also returns a spatial complexity map (ic_map); only the
        # scalar score is used here. (Optional .npy / blended heat-map dumps
        # were removed as dead commented-out code.)
        ic_score, _ic_map = model(img)
        return ic_score.item()
def infer_directory(img_dir):
    """Score every image in img_dir and print the 50 highest-scoring paths.

    Relies on the module-global `model` created in the __main__ block.

    Args:
        img_dir: directory whose entries are all image files.
    """
    scores = []
    for name in tqdm(sorted(os.listdir(img_dir))):
        img_path = os.path.join(img_dir, name)
        # Bug fix: the original called infer_one_image(img_path) without the
        # required `model` argument, raising TypeError on the first image.
        score = infer_one_image(model, img_path)
        scores.append((score, img_path))
        print(img_path, score)
    # Highest complexity first; print the top 50 (score, path) pairs.
    for entry in sorted(scores, key=lambda x: x[0], reverse=True)[:50]:
        print(entry)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='./example',
                        help='image file, or directory of images, to score')
    parser.add_argument('-o', '--output', type=str, default='./out',
                        help='output directory (reserved for map dumps)')
    parser.add_argument('-d', '--device', type=int, default=0,
                        help='CUDA device index to run the model on')
    args = parser.parse_args()

    # Load the checkpoint on CPU first, then move the model to the target
    # device so loading works even when the checkpoint was saved on GPU.
    model = ICNet()
    model.load_state_dict(torch.load('./checkpoint/ck.pth',
                                     map_location=torch.device('cpu')))
    model.eval()
    model.to(torch.device(args.device))

    # NOTE: the redundant re-definition of inference_transform (identical to
    # the module-level one) was removed.

    if os.path.isfile(args.input):
        # Bug fix: the original called infer_one_image(args.input) without the
        # required `model` argument and discarded the result.
        print(args.input, infer_one_image(model, args.input))
    else:
        infer_directory(args.input)