Spaces:
Sleeping
Sleeping
| """Compute segmentation maps for images in the input folder. | |
| """ | |
| import os | |
| import glob | |
| import cv2 | |
| import argparse | |
| import torch | |
| import torch.nn.functional as F | |
| import util.io | |
| from torchvision.transforms import Compose | |
| from dpt.models import DPTSegmentationModel | |
| from dpt.transforms import Resize, NormalizeImage, PrepareForNet | |
def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True):
    """Run segmentation network on every file in a folder.

    Each regular file found directly inside ``input_path`` is read with
    ``util.io.read_image``, resized/normalized, passed through the network,
    and the per-pixel argmax class map is written to ``output_path`` with
    ``util.io.write_segm_img`` (``alpha=0.5``), named after the input
    file's stem.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder (created if missing)
        model_path (str): path to saved model
        model_type (str): "dpt_large" or "dpt_hybrid"; selects the backbone
        optimize (bool): when running on CUDA, convert model and inputs to
            half precision with channels-last memory format; ignored on CPU

    Raises:
        ValueError: if ``model_type`` is not a supported model type.
    """
    print("initialize")

    # select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)
    # Half precision + channels-last is only applied on CUDA devices.
    use_half = bool(optimize) and device.type == "cuda"

    # Backbone name for each supported model type.
    backbones = {
        "dpt_large": "vitl16_384",
        "dpt_hybrid": "vitb_rn50_384",
    }
    # Explicit exception instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn this into a later NameError on ``model``.
    if model_type not in backbones:
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]"
        )

    net_w = net_h = 480

    # load network
    # NOTE(review): the leading 150 is presumably the number of output classes
    # (ADE20K has 150, matching the default checkpoints) — confirm against
    # DPTSegmentationModel's signature.
    model = DPTSegmentationModel(
        150,
        path=model_path,
        backbone=backbones[model_type],
    )

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        ]
    )

    model.eval()
    if use_half:
        model = model.to(memory_format=torch.channels_last)
        model = model.half()
    model.to(device)

    # get input: regular files only (sub-directories cannot be read as images),
    # sorted so the processing order is deterministic.
    img_names = sorted(
        name
        for name in glob.glob(os.path.join(input_path, "*"))
        if os.path.isfile(name)
    )
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")
    for ind, img_name in enumerate(img_names):
        print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input
        img = util.io.read_image(img_name)
        img_input = transform({"image": img})["image"]

        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
            if use_half:
                sample = sample.to(memory_format=torch.channels_last)
                sample = sample.half()
            out = model(sample)
            # Resize the network output back to the input image's height/width
            # (img.shape[:2]) before taking the per-pixel argmax.
            prediction = F.interpolate(
                out, size=img.shape[:2], mode="bicubic", align_corners=False
            )
            # Argmax over dim 1 gives 0-based indices; +1 makes them 1-based
            # (presumably so 0 stays free as "unlabeled" for write_segm_img —
            # confirm against util.io).
            prediction = torch.argmax(prediction, dim=1) + 1
            # Drop only the batch dimension: a bare squeeze() would also
            # collapse a height or width of 1.
            prediction = prediction.squeeze(0).cpu().numpy()

        # output
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        util.io.write_segm_img(filename, img, prediction, alpha=0.5)

    print("finished")
if __name__ == "__main__":
    # Default checkpoint for each supported model type; its keys also define
    # the values accepted by --model_type.
    default_models = {
        "dpt_large": "weights/dpt_large-ade20k-b12dca68.pt",
        "dpt_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt",
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input_path", default="input", help="folder with input images"
    )
    parser.add_argument(
        "-o", "--output_path", default="output_semseg", help="folder for output images"
    )
    parser.add_argument(
        "-m",
        "--model_weights",
        default=None,
        help="path to the trained weights of model",
    )
    # Restricting to known types makes argparse reject a typo up front instead
    # of crashing later with a bare KeyError in the default_models lookup.
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        choices=sorted(default_models),
        help="model type",
    )
    parser.add_argument(
        "--optimize",
        dest="optimize",
        action="store_true",
        help="on CUDA, run in half precision with channels-last memory format (default)",
    )
    parser.add_argument(
        "--no-optimize",
        dest="optimize",
        action="store_false",
        help="disable the half-precision / channels-last optimization",
    )
    parser.set_defaults(optimize=True)
    args = parser.parse_args()

    # Fall back to the bundled checkpoint for the chosen model type.
    if args.model_weights is None:
        args.model_weights = default_models[args.model_type]

    # set torch options
    torch.backends.cudnn.enabled = True
    # Let cuDNN benchmark and select convolution algorithms.
    torch.backends.cudnn.benchmark = True

    # compute segmentation maps
    run(
        args.input_path,
        args.output_path,
        args.model_weights,
        args.model_type,
        args.optimize,
    )