# Modified from the implementation of https://huggingface.co/akhaliq
import os
import sys

# Clone the GroupViT repo and put it on sys.path so that its repo-local
# modules (datasets, models, segmentation, utils) can be imported below.
os.system("git clone https://github.com/NVlabs/GroupViT")
sys.path.insert(0, 'GroupViT')
import os.path as osp
import shutil
from collections import namedtuple

import gradio as gr
import mmcv
import numpy as np
import torch
from datasets import build_text_transform
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.image import tensor2imgs
from mmcv.parallel import collate, scatter
from models import build_model
from omegaconf import read_write
from segmentation.datasets import (COCOObjectDataset, PascalContextDataset,
                                   PascalVOCDataset)
from segmentation.evaluation import (GROUP_PALETTE, build_seg_demo_pipeline,
                                     build_seg_inference)
from utils import get_config, load_checkpoint
# Copy the bundled demo assets into the cloned repo, then work from inside it.
if not osp.exists('GroupViT/hg_demo'):
    shutil.copytree('demo/', 'GroupViT/hg_demo/')
os.chdir('GroupViT')
# checkpoint_url = 'https://github.com/xvjiarui/GroupViT-1/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-74d335e6.pth'
checkpoint_url = 'https://github.com/xvjiarui/GroupViT/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-879422e0.pth'
cfg_path = 'configs/group_vit_gcc_yfcc_30e.yml'
output_dir = 'demo/output'
device = 'cpu'

# Visualization modes to render for each request, with matching output labels.
# vis_modes = ['first_group', 'final_group', 'input_pred_label']
vis_modes = ['input_pred_label', 'final_group']
output_labels = ['segmentation map', 'groups']
dataset_options = ['Pascal VOC', 'Pascal Context', 'COCO']

examples = [['Pascal VOC', '', 'hg_demo/voc.jpg'],
            ['Pascal Context', '', 'hg_demo/ctx.jpg'],
            ['COCO', '', 'hg_demo/coco.jpg']]
# Mimic the command-line arguments that get_config() expects; `resume`
# points at the pretrained checkpoint to download and load.
PSEUDO_ARGS = namedtuple('PSEUDO_ARGS',
                         ['cfg', 'opts', 'resume', 'vis', 'local_rank'])
args = PSEUDO_ARGS(
    cfg=cfg_path, opts=[], resume=checkpoint_url, vis=vis_modes, local_rank=0)

cfg = get_config(args)
with read_write(cfg):
    cfg.evaluate.eval_only = True

# Build the model once at startup, load the pretrained weights, and set up
# the text transform and the demo image pipeline shared by all requests.
model = build_model(cfg.model)
model = revert_sync_batchnorm(model)
model.to(device)
model.eval()
load_checkpoint(cfg, model, None, None)

text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
test_pipeline = build_seg_demo_pipeline()
def inference(dataset, additional_classes, input_img):
    # Map the dropdown choice to its dataset class and segmentation config.
    if dataset == 'voc' or dataset == 'Pascal VOC':
        dataset_class = PascalVOCDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
    elif dataset == 'coco' or dataset == 'COCO':
        dataset_class = COCOObjectDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/coco.py'
    elif dataset == 'context' or dataset == 'Pascal Context':
        dataset_class = PascalContextDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
    else:
        raise ValueError(f'Unknown dataset: {dataset}')
    with read_write(cfg):
        cfg.evaluate.seg.cfg = seg_cfg
        cfg.evaluate.seg.opts = ['test_cfg.mode=whole']

    dataset_cfg = mmcv.Config()
    dataset_cfg.CLASSES = list(dataset_class.CLASSES)
    dataset_cfg.PALETTE = dataset_class.PALETTE.copy()

    # Open-vocabulary support: extend the class list with any user-supplied
    # classes and assign each new class a random color from the group palette.
    if len(additional_classes) > 0:
        additional_classes = [
            c.strip() for c in additional_classes.split(',') if c.strip()
        ]
        additional_classes = list(
            set(additional_classes) - set(dataset_cfg.CLASSES))
        dataset_cfg.CLASSES.extend(additional_classes)
        dataset_cfg.PALETTE.extend(GROUP_PALETTE[np.random.choice(
            list(range(len(GROUP_PALETTE))), len(additional_classes))])
    seg_model = build_seg_inference(model, dataset_cfg, text_transform,
                                    cfg.evaluate.seg)
    device = next(seg_model.parameters()).device

    # Prepare the input image for the test pipeline.
    data = dict(img=input_img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(seg_model.parameters()).is_cuda:
        # Scatter to the specified GPU.
        data = scatter(data, [device])[0]
    else:
        data['img_metas'] = [i.data[0] for i in data['img_metas']]

    with torch.no_grad():
        result = seg_model(return_loss=False, rescale=False, **data)

    img_tensor = data['img'][0]
    img_metas = data['img_metas'][0]
    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
    assert len(imgs) == len(img_metas)

    # Render one image file per visualization mode.
    out_file_dict = dict()
    for img, img_meta in zip(imgs, img_metas):
        h, w, _ = img_meta['img_shape']
        img_show = img[:h, :w, :]
        # ori_h, ori_w = img_meta['ori_shape'][:-1]
        # short_side = 448
        # if ori_h > ori_w:
        #     new_h, new_w = ori_h * short_side // ori_w, short_side
        # else:
        #     new_w, new_h = ori_w * short_side // ori_h, short_side
        # img_show = mmcv.imresize(img_show, (new_w, new_h))
        for vis_mode in vis_modes:
            out_file = osp.join(output_dir, 'vis_imgs', vis_mode,
                                f'{vis_mode}.jpg')
            seg_model.show_result(img_show, img_tensor.to(device), result,
                                  out_file, vis_mode)
            out_file_dict[vis_mode] = out_file

    return [out_file_dict[mode] for mode in vis_modes]
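
# A minimal local sanity check, bypassing the Gradio UI; 'helmet' here is
# just an illustrative extra class, not part of Pascal VOC:
#
#   seg_map_path, groups_path = inference('Pascal VOC', 'helmet',
#                                         'hg_demo/voc.jpg')
#
# Each returned element is the path of a rendered JPEG under demo/output.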
title = 'GroupViT'
description = """
Gradio demo for GroupViT: Semantic Segmentation Emerges from Text Supervision. \n
You may click one of the examples or upload your own image. \n
GroupViT can perform open-vocabulary segmentation: you may input additional classes (separated by commas).
"""
article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2202.11094' target='_blank'>
GroupViT: Semantic Segmentation Emerges from Text Supervision
</a>
|
<a href='https://github.com/NVlabs/GroupViT' target='_blank'>Github Repo</a></p>
"""
gr.Interface(
    inference,
    inputs=[
        gr.inputs.Dropdown(dataset_options, type='value', label='Category list'),
        gr.inputs.Textbox(
            lines=1, placeholder=None, default='', label='More classes'),
        gr.inputs.Image(type='filepath')
    ],
    outputs=[gr.outputs.Image(label=label) for label in output_labels],
    title=title,
    description=description,
    article=article,
    examples=examples).launch(enable_queue=True)