# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import mmcv
import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
from rich.progress import track
from torch import Tensor

from mmocr.utils import ConfigType

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


class BaseMMOCRInferencer(BaseInferencer):
    """Base inferencer.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "dbnet_resnet18_fpnc_1200e_icdar2015" or
            "configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py".
            If model is not specified, the user must provide the
            ``weights`` saved by MMEngine, which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and model is a model name of metafile, the weights
            will be loaded from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmocr".
    """

    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
        'save_vis'
    }
    postprocess_kwargs: set = {
        'print_result', 'return_datasample', 'save_pred'
    }
    loading_transforms: list = ['LoadImageFromFile', 'LoadImageFromNDArray']

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: str = 'mmocr') -> None:
        # A global counter tracking the number of images given in the form
        # of ndarray, for naming the output images
        self.num_unnamed_imgs = 0
        init_default_scope(scope)
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)
        self.model = revert_sync_batchnorm(self.model)
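
    # Illustrative sketch (not part of the upstream module): constructing a
    # concrete subclass. The subclass name ``TextDetInferencer`` and the CUDA
    # device string are assumptions for illustration; the metafile model name
    # is the one quoted in the class docstring above.
    #
    #   >>> from mmocr.apis import TextDetInferencer
    #   >>> inferencer = TextDetInferencer(
    #   ...     model='dbnet_resnet18_fpnc_1200e_icdar2015',
    #   ...     device='cuda:0')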

    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): Batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """
        chunked_data = self._get_chunk_data(inputs, batch_size)
        yield from map(self.collate_fn, chunked_data)

    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
        """Get batch data from inputs.

        Args:
            inputs (Iterable): An iterable dataset.
            chunk_size (int): Equivalent to batch size.

        Yields:
            list: Batch data.
        """
        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    inputs_ = next(inputs_iter)
                    pipe_out = self.pipeline(inputs_)
                    if pipe_out['data_samples'].get('img_path') is None:
                        pipe_out['data_samples'].set_metainfo(
                            dict(img_path=f'{self.num_unnamed_imgs}.jpg'))
                        self.num_unnamed_imgs += 1
                    chunk_data.append((inputs_, pipe_out))
                yield chunk_data
            except StopIteration:
                if chunk_data:
                    yield chunk_data
                break
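
    # A minimal, self-contained sketch (an assumption, for illustration only)
    # of the batching behaviour above: with ``chunk_size=2`` and three
    # inputs, two chunks are yielded and the last one is smaller than the
    # batch size.
    #
    #   >>> from itertools import islice
    #   >>> def chunks(items, chunk_size):
    #   ...     it = iter(items)
    #   ...     while True:
    #   ...         batch = list(islice(it, chunk_size))
    #   ...         if not batch:
    #   ...             break
    #   ...         yield batch
    #   >>> list(chunks(['a.jpg', 'b.jpg', 'c.jpg'], 2))
    #   [['a.jpg', 'b.jpg'], ['c.jpg']]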

    def __call__(self,
                 inputs: InputsType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 progress_bar: bool = True,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 draw_pred: bool = True,
                 pred_score_thr: float = 0.3,
                 out_dir: str = 'results/',
                 save_vis: bool = False,
                 save_pred: bool = False,
                 print_result: bool = False,
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer. It can be a path
                to an image or an image directory, an array, or a list of
                these. Note: if it is a numpy array, it should be in BGR
                order.
            return_datasamples (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            batch_size (int): Inference batch size. Defaults to 1.
            progress_bar (bool): Whether to show a progress bar. Defaults to
                True.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the visualization results in a
                popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            out_dir (str): Output directory of results. Defaults to
                'results/'.
            save_vis (bool): Whether to save the visualization results to
                "out_dir". Defaults to False.
            save_pred (bool): Whether to save the inference results to
                "out_dir". Defaults to False.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results, mapped from
            "predictions" and "visualization".
        """
        if (save_vis or save_pred) and not out_dir:
            raise ValueError('out_dir must be specified when save_vis or '
                             'save_pred is True!')
        if out_dir:
            img_out_dir = osp.join(out_dir, 'vis')
            pred_out_dir = osp.join(out_dir, 'preds')
        else:
            img_out_dir, pred_out_dir = '', ''
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(
            return_vis=return_vis,
            show=show,
            wait_time=wait_time,
            draw_pred=draw_pred,
            pred_score_thr=pred_score_thr,
            save_vis=save_vis,
            save_pred=save_pred,
            print_result=print_result,
            **kwargs)

        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(
            ori_inputs, batch_size=batch_size, **preprocess_kwargs)
        results = {'predictions': [], 'visualization': []}
        for ori_inputs, data in track(
                inputs, description='Inference', disable=not progress_bar):
            preds = self.forward(data, **forward_kwargs)
            visualization = self.visualize(
                ori_inputs, preds, img_out_dir=img_out_dir,
                **visualize_kwargs)
            batch_res = self.postprocess(
                preds,
                visualization,
                return_datasamples,
                pred_out_dir=pred_out_dir,
                **postprocess_kwargs)
            results['predictions'].extend(batch_res['predictions'])
            if return_vis and batch_res['visualization'] is not None:
                results['visualization'].extend(batch_res['visualization'])
        return results
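
    # Illustrative call (an assumption: ``inferencer`` is an instance of a
    # concrete subclass and ``demo.jpg`` is an existing image). The returned
    # dict always carries the two keys documented above; ``visualization``
    # stays empty unless ``return_vis=True``.
    #
    #   >>> result = inferencer(
    #   ...     'demo.jpg',
    #   ...     batch_size=2,
    #   ...     save_pred=True,
    #   ...     save_vis=True,
    #   ...     out_dir='results/')
    #   >>> sorted(result.keys())
    #   ['predictions', 'visualization']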

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # For inference, the key of ``instances`` is not used.
        if 'meta_keys' in pipeline_cfg[-1]:
            pipeline_cfg[-1]['meta_keys'] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
                if meta_key != 'instances')

        # Loading annotations is also not applicable
        idx = self._get_transform_idx(pipeline_cfg, 'LoadOCRAnnotations')
        if idx != -1:
            del pipeline_cfg[idx]

        for transform in self.loading_transforms:
            load_img_idx = self._get_transform_idx(pipeline_cfg, transform)
            if load_img_idx != -1:
                pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
                break
        if load_img_idx == -1:
            raise ValueError(
                f'None of {self.loading_transforms} is found in the test '
                'pipeline')
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1
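
    # Hedged sketch of the pipeline rewrite performed by ``_init_pipeline``.
    # The pipeline below is hypothetical (transform names other than the ones
    # referenced in the code above are illustrative):
    #
    #   before: [dict(type='LoadImageFromFile'),
    #            dict(type='Resize', scale=(1333, 736)),
    #            dict(type='LoadOCRAnnotations', with_bbox=True),
    #            dict(type='PackTextDetInputs',
    #                 meta_keys=('img_path', 'ori_shape', 'instances'))]
    #   after:  [dict(type='InferencerLoader'),
    #            dict(type='Resize', scale=(1333, 736)),
    #            dict(type='PackTextDetInputs',
    #                 meta_keys=('img_path', 'ori_shape'))]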

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  save_vis: bool = False,
                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[Dict]): Predictions of the model.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            save_vis (bool): Whether to save the visualization result.
                Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Returns visualization results only if
            applicable.
        """
        if self.visualizer is None or not (show or save_vis or return_vis):
            return None

        if getattr(self, 'visualizer') is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()[:, :, ::-1]  # to RGB
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')
            img_name = osp.splitext(osp.basename(pred.img_path))[0]

            if save_vis and img_out_dir:
                out_file = osp.splitext(img_name)[0]
                out_file = f'{out_file}.jpg'
                out_file = osp.join(img_out_dir, out_file)
            else:
                out_file = None

            visualization = self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(visualization)

        return results
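
    # Worked example of the output-file naming above (paths are illustrative):
    #
    #   >>> import os.path as osp
    #   >>> img_name = osp.splitext(osp.basename('demo/img_1.png'))[0]
    #   >>> osp.join('results/vis', f'{osp.splitext(img_name)[0]}.jpg')
    #   'results/vis/img_1.jpg'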

    def postprocess(
        self,
        preds: PredType,
        visualization: Optional[List[np.ndarray]] = None,
        return_datasample: bool = False,
        print_result: bool = False,
        save_pred: bool = False,
        pred_out_dir: str = '',
    ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (Optional[np.ndarray]): Visualized predictions.
            return_datasample (bool): Whether to use Datasample to store
                inference results. If False, dict will be used.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            save_pred (bool): Whether to save the inference result. Defaults
                to False.
            pred_out_dir (str): Directory to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``.

            - ``visualization`` (Any): Returned by :meth:`visualize`.
            - ``predictions`` (dict or DataSample): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it usually should be a
              json-serializable dict containing only basic data elements such
              as strings and numbers.
        """
        result_dict = {}
        results = preds
        if not return_datasample:
            results = []
            for pred in preds:
                result = self.pred2dict(pred)
                if save_pred and pred_out_dir:
                    pred_name = osp.splitext(osp.basename(pred.img_path))[0]
                    pred_name = f'{pred_name}.json'
                    pred_out_file = osp.join(pred_out_dir, pred_name)
                    mmengine.dump(result, pred_out_file)
                results.append(result)
        # Add img to the results after printing and dumping
        result_dict['predictions'] = results
        if print_result:
            print(result_dict)
        result_dict['visualization'] = visualization
        return result_dict
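
    # Hedged sketch of the returned structure for a single image when
    # ``return_datasample=False``. The field names inside each prediction are
    # subclass-specific and therefore assumptions here:
    #
    #   {
    #       'predictions': [{'polygons': [...], 'scores': [...]}],
    #       'visualization': None,  # or a list of np.ndarray images
    #   }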

    def pred2dict(self, data_sample: InstanceData) -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        It's better to contain only basic data elements such as strings and
        numbers in order to guarantee it's json-serializable.
        """
        raise NotImplementedError
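
    # A minimal sketch of a subclass override (an assumption; concrete
    # inferencers define their own fields). It only pulls basic Python types
    # out of the data sample so the result stays json-serializable:
    #
    #   def pred2dict(self, data_sample):
    #       pred = data_sample.pred_instances
    #       return dict(
    #           scores=self._array2list(pred.scores),
    #           polygons=self._array2list(pred.polygons))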

    def _array2list(self, array: Union[Tensor, np.ndarray,
                                       List]) -> List[float]:
        """Convert a tensor or numpy array to a list.

        Args:
            array (Union[Tensor, np.ndarray, List]): The array to be
                converted.

        Returns:
            List[float]: The converted list.
        """
        if isinstance(array, Tensor):
            return array.detach().cpu().numpy().tolist()
        if isinstance(array, np.ndarray):
            return array.tolist()
        if isinstance(array, list):
            array = [self._array2list(arr) for arr in array]
        return array
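
    # Quick illustration of the conversions handled above (doctest-style,
    # assuming ``self`` is any inferencer instance):
    #
    #   >>> import numpy as np
    #   >>> import torch
    #   >>> self._array2list(torch.tensor([1.0, 2.0]))
    #   [1.0, 2.0]
    #   >>> self._array2list(np.array([[1, 2], [3, 4]]))
    #   [[1, 2], [3, 4]]
    #   >>> self._array2list([np.array([1.0]), torch.tensor([2.0])])
    #   [[1.0], [2.0]]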