Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| # MLSD Line Detection | |
| # From https://github.com/navervision/mlsd | |
| # Apache-2.0 license | |
| import warnings | |
| from abc import ABCMeta | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from scepter.modules.annotator.base_annotator import BaseAnnotator | |
| from scepter.modules.annotator.mlsd.mbv2_mlsd_large import MobileV2_MLSD_Large | |
| from scepter.modules.annotator.mlsd.utils import pred_lines | |
| from scepter.modules.annotator.registry import ANNOTATORS | |
| from scepter.modules.annotator.utils import resize_image, resize_image_ori | |
| from scepter.modules.utils.config import dict_to_yaml | |
| from scepter.modules.utils.distribute import we | |
| from scepter.modules.utils.file_system import FS | |
class MLSDdetector(BaseAnnotator, metaclass=ABCMeta):
    """MLSD line-segment detection annotator.

    Runs the MobileV2_MLSD_Large model over an input image and renders the
    detected line segments as white 1-px lines on a black canvas, returned as
    a single-channel uint8 map at the original image resolution.

    Config keys:
        PRETRAINED_MODEL: optional path/URI of a state dict to load via FS.
        THR_V: score threshold passed to ``pred_lines`` (default 0.1).
        THR_D: distance threshold passed to ``pred_lines`` (default 0.1).
    """
    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        model = MobileV2_MLSD_Large()
        pretrained_model = cfg.get('PRETRAINED_MODEL', None)
        if pretrained_model:
            with FS.get_from(pretrained_model, wait_finish=True) as local_path:
                # map_location='cpu' keeps the load device-agnostic: without
                # it, torch.load tries to restore onto the device the weights
                # were saved from, which fails on CPU-only hosts.
                model.load_state_dict(torch.load(local_path,
                                                 map_location='cpu'),
                                      strict=True)
        self.model = model.eval()
        self.thr_v = cfg.get('THR_V', 0.1)
        self.thr_d = cfg.get('THR_D', 0.1)

    @torch.no_grad()
    def forward(self, image):
        """Detect line segments in ``image``.

        Args:
            image: PIL Image, torch.Tensor or HxWxC np.ndarray.

        Returns:
            HxW uint8 np.ndarray line map at the input resolution, or None
            if line prediction failed (a warning is emitted).

        Raises:
            TypeError: if ``image`` is not one of the supported types.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        elif isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        elif isinstance(image, np.ndarray):
            image = image.copy()
        else:
            # The original code raised a bare f-string, which itself raises
            # "exceptions must derive from BaseException" and hides the
            # intended message; raise a proper TypeError instead.
            raise TypeError(
                f'Unsupported datatype {type(image)}; only np.ndarray, '
                f'torch.Tensor and Pillow Image are supported.')
        h, w, c = image.shape
        # Cap the working resolution at 1024 on the short side; k is the
        # scale factor needed to map the result back afterwards.
        image, k = resize_image(image, 1024 if min(h, w) > 1024 else min(h, w))
        img_output = np.zeros_like(image)
        try:
            lines = pred_lines(image,
                               self.model, [image.shape[0], image.shape[1]],
                               self.thr_v,
                               self.thr_d,
                               device=we.device_id)
            for line in lines:
                x_start, y_start, x_end, y_end = [int(val) for val in line]
                cv2.line(img_output, (x_start, y_start), (x_end, y_end),
                         [255, 255, 255], 1)
        except Exception as e:
            # Best-effort: prediction failures degrade to a None result
            # rather than aborting the caller's pipeline.
            warnings.warn(f'{e}')
            return None
        img_output = resize_image_ori(h, w, img_output, k)
        # All three channels are identical; return one as the line map.
        return img_output[:, :, 0]

    @staticmethod
    def get_config_template():
        # @staticmethod added: the function takes no self/cls, so without the
        # decorator an instance-level call would fail with an argument error.
        return dict_to_yaml('ANNOTATORS',
                            __class__.__name__,
                            MLSDdetector.para_dict,
                            set_name=True)