diff --git a/mmocr/apis/inferencers/kie_inferencer.py b/mmocr/apis/inferencers/kie_inferencer.py
index c7865d5c9..487e7b6ab 100644
--- a/mmocr/apis/inferencers/kie_inferencer.py
+++ b/mmocr/apis/inferencers/kie_inferencer.py
@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-import os.path as osp
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import mmcv
 import mmengine
@@ -12,7 +11,7 @@
 from mmocr.registry import DATASETS
 from mmocr.structures import KIEDataSample
 from mmocr.utils import ConfigType
-from .base_mmocr_inferencer import BaseMMOCRInferencer, ModelType, PredType
+from .base_mmocr_inferencer import BaseMMOCRInferencer, ModelType
 
 InputType = Dict
 InputsType = Sequence[Dict]
@@ -186,85 +185,6 @@ def _inputs_to_list(self, inputs: InputsType) -> list:
 
         return processed_inputs
 
-    def visualize(self,
-                  inputs: InputsType,
-                  preds: PredType,
-                  return_vis: bool = False,
-                  show: bool = False,
-                  wait_time: int = 0,
-                  draw_pred: bool = True,
-                  pred_score_thr: float = 0.3,
-                  save_vis: bool = False,
-                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
-        """Visualize predictions.
-
-        Args:
-            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
-            preds (List[Dict]): Predictions of the model.
-            return_vis (bool): Whether to return the visualization result.
-                Defaults to False.
-            show (bool): Whether to display the image in a popup window.
-                Defaults to False.
-            wait_time (float): The interval of show (s). Defaults to 0.
-            draw_pred (bool): Whether to draw predicted bounding boxes.
-                Defaults to True.
-            pred_score_thr (float): Minimum score of bboxes to draw.
-                Defaults to 0.3.
-            save_vis (bool): Whether to save the visualization result. Defaults
-                to False.
-            img_out_dir (str): Output directory of visualization results.
-                If left as empty, no file will be saved. Defaults to ''.
-
-        Returns:
-            List[np.ndarray] or None: Returns visualization results only if
-            applicable.
-        """
-        if self.visualizer is None or not (show or save_vis or return_vis):
-            return None
-
-        if getattr(self, 'visualizer') is None:
-            raise ValueError('Visualization needs the "visualizer" term'
-                             'defined in the config, but got None.')
-
-        results = []
-
-        for single_input, pred in zip(inputs, preds):
-            assert 'img' in single_input or 'img_shape' in single_input
-            if 'img' in single_input:
-                if isinstance(single_input['img'], str):
-                    img_bytes = mmengine.fileio.get(single_input['img'])
-                    img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
-                elif isinstance(single_input['img'], np.ndarray):
-                    img = single_input['img'].copy()[:, :, ::-1]  # To RGB
-            elif 'img_shape' in single_input:
-                img = np.zeros(single_input['img_shape'], dtype=np.uint8)
-            else:
-                raise ValueError('Input does not contain either "img" or '
-                                 '"img_shape"')
-            img_name = osp.splitext(osp.basename(pred.img_path))[0]
-
-            if save_vis and img_out_dir:
-                out_file = osp.splitext(img_name)[0]
-                out_file = f'{out_file}.jpg'
-                out_file = osp.join(img_out_dir, out_file)
-            else:
-                out_file = None
-
-            visualization = self.visualizer.add_datasample(
-                img_name,
-                img,
-                pred,
-                show=show,
-                wait_time=wait_time,
-                draw_gt=False,
-                draw_pred=draw_pred,
-                pred_score_thr=pred_score_thr,
-                out_file=out_file,
-            )
-            results.append(visualization)
-
-        return results
-
     def pred2dict(self, data_sample: KIEDataSample) -> Dict:
         """Extract elements necessary to represent a prediction into a
         dictionary.
 
         It's better to contain only basic data elements such as
diff --git a/mmocr/apis/inferencers/mmocr_inferencer.py b/mmocr/apis/inferencers/mmocr_inferencer.py
index c531be35c..f87447136 100644
--- a/mmocr/apis/inferencers/mmocr_inferencer.py
+++ b/mmocr/apis/inferencers/mmocr_inferencer.py
@@ -143,9 +143,8 @@ def forward(self,
             kie_batch_size = batch_size
         if self.mode == 'rec':
             # The extra list wrapper here is for the ease of postprocessing
-            self.rec_inputs = inputs
             predictions = self.textrec_inferencer(
-                self.rec_inputs,
+                inputs,
                 return_datasamples=True,
                 batch_size=rec_batch_size,
                 **forward_kwargs)['predictions']
@@ -161,20 +160,20 @@ def forward(self,
                 for img, det_data_sample in zip(
                         self._inputs2ndarrray(inputs), result['det']):
                     det_pred = det_data_sample.pred_instances
-                    self.rec_inputs = []
+                    rec_inputs = []
                     for polygon in det_pred['polygons']:
                         # Roughly convert the polygon to a quadangle with
                         # 4 points
                         quad = bbox2poly(poly2bbox(polygon)).tolist()
-                        self.rec_inputs.append(crop_img(img, quad))
+                        rec_inputs.append(crop_img(img, quad))
                     result['rec'].append(
                         self.textrec_inferencer(
-                            self.rec_inputs,
+                            rec_inputs,
                             return_datasamples=True,
                             batch_size=rec_batch_size,
                             **forward_kwargs)['predictions'])
                 if self.mode == 'det_rec_kie':
-                    self.kie_inputs = []
+                    kie_inputs = []
                     # TODO: when the det output is empty, kie will fail
                     # as no gt-instances can be provided. It's a known
                     # issue but cannot be solved elegantly since we support
@@ -190,9 +189,9 @@ def forward(self,
                             dict(
                                 bbox=poly2bbox(polygon),
                                 text=rec_data_sample.pred_text.item))
-                    self.kie_inputs.append(kie_input)
+                    kie_inputs.append(kie_input)
                 result['kie'] = self.kie_inferencer(
-                    self.kie_inputs,
+                    kie_inputs,
                     return_datasamples=True,
                     batch_size=kie_batch_size,
                     **forward_kwargs)['predictions']
@@ -223,7 +222,7 @@ def visualize(self, inputs: InputsType, preds: PredType,
         """
         if 'kie' in self.mode:
-            return self.kie_inferencer.visualize(self.kie_inputs, preds['kie'],
+            return self.kie_inferencer.visualize(inputs, preds['kie'],
                                                  **kwargs)
         elif 'rec' in self.mode:
             if 'det' in self.mode:
@@ -232,7 +231,7 @@ def visualize(self, inputs: InputsType, preds: PredType,
                     **kwargs)
             else:
                 return self.textrec_inferencer.visualize(
-                    self.rec_inputs, preds['rec'][0], **kwargs)
+                    inputs, preds['rec'][0], **kwargs)
         else:
             return self.textdet_inferencer.visualize(inputs, preds['det'],
                                                      **kwargs)
diff --git a/mmocr/visualization/kie_visualizer.py b/mmocr/visualization/kie_visualizer.py
index 753bac2e9..b882f9913 100644
--- a/mmocr/visualization/kie_visualizer.py
+++ b/mmocr/visualization/kie_visualizer.py
@@ -7,6 +7,7 @@
 import torch
 from matplotlib.collections import PatchCollection
 from matplotlib.patches import FancyArrow
+from mmengine.utils.manager import _accquire_lock, _release_lock
 from mmengine.visualization import Visualizer
 from mmengine.visualization.utils import (check_type, check_type_and_length,
                                           color_val_matplotlib, tensor2ndarray,
@@ -233,6 +234,7 @@ def add_datasample(self,
             out_file (str): Path to output file. Defaults to None.
             step (int): Global step value to record. Defaults to 0.
         """
+        _accquire_lock()
         cat_images = list()
 
         if draw_gt:
@@ -274,7 +276,10 @@ def add_datasample(self,
             mmcv.imwrite(cat_images[..., ::-1], out_file)
 
         self.set_image(cat_images)
-        return self.get_image()
+        drawn_rgb_image = self.get_image()
+        _release_lock()
+
+        return drawn_rgb_image
 
     def draw_arrows(self,
                     x_data: Union[np.ndarray, torch.Tensor],
diff --git a/mmocr/visualization/textdet_visualizer.py b/mmocr/visualization/textdet_visualizer.py
index 8b3f54da1..dcfaf1d9a 100644
--- a/mmocr/visualization/textdet_visualizer.py
+++ b/mmocr/visualization/textdet_visualizer.py
@@ -4,6 +4,7 @@
 import mmcv
 import numpy as np
 import torch
+from mmengine.utils.manager import _accquire_lock, _release_lock
 
 from mmocr.registry import VISUALIZERS
 from mmocr.structures import TextDetDataSample
@@ -147,6 +148,7 @@ def add_datasample(self,
                 and masks. Defaults to 0.3.
             step (int): Global step value to record. Defaults to 0.
         """
+        _accquire_lock()
         cat_images = []
         if data_sample is not None:
             if draw_gt and 'gt_instances' in data_sample:
@@ -191,4 +193,7 @@ def add_datasample(self,
             mmcv.imwrite(cat_images[..., ::-1], out_file)
 
         self.set_image(cat_images)
-        return self.get_image()
+        drawn_rgb_image = self.get_image()
+        _release_lock()
+
+        return drawn_rgb_image
diff --git a/mmocr/visualization/textrecog_visualizer.py b/mmocr/visualization/textrecog_visualizer.py
index d2f529b47..d2114d929 100644
--- a/mmocr/visualization/textrecog_visualizer.py
+++ b/mmocr/visualization/textrecog_visualizer.py
@@ -4,6 +4,7 @@
 import cv2
 import mmcv
 import numpy as np
+from mmengine.utils.manager import _accquire_lock, _release_lock
 
 from mmocr.registry import VISUALIZERS
 from mmocr.structures import TextRecogDataSample
@@ -112,6 +113,7 @@ def add_datasample(self,
             pred_score_thr (float): Threshold of prediction score. It's not
                 used in this function. Defaults to None.
         """
+        _accquire_lock()
         height, width = image.shape[:2]
         resize_height = 64
         resize_width = int(1.0 * width / height * resize_height)
@@ -141,4 +143,7 @@ def add_datasample(self,
             mmcv.imwrite(cat_images[..., ::-1], out_file)
 
         self.set_image(cat_images)
-        return self.get_image()
+        drawn_rgb_image = self.get_image()
+        _release_lock()
+
+        return drawn_rgb_image
diff --git a/mmocr/visualization/textspotting_visualizer.py b/mmocr/visualization/textspotting_visualizer.py
index bd4038c35..28effa1b4 100644
--- a/mmocr/visualization/textspotting_visualizer.py
+++ b/mmocr/visualization/textspotting_visualizer.py
@@ -4,6 +4,7 @@
 import mmcv
 import numpy as np
 import torch
+from mmengine.utils.manager import _accquire_lock, _release_lock
 
 from mmocr.registry import VISUALIZERS
 from mmocr.structures import TextDetDataSample
@@ -103,6 +104,7 @@ def add_datasample(self,
                 and masks. Defaults to 0.3.
             step (int): Global step value to record. Defaults to 0.
         """
+        _accquire_lock()
         cat_images = []
 
         if data_sample is not None:
@@ -141,4 +143,7 @@ def add_datasample(self,
             mmcv.imwrite(cat_images[..., ::-1], out_file)
 
         self.set_image(cat_images)
-        return self.get_image()
+        drawn_rgb_image = self.get_image()
+        _release_lock()
+
+        return drawn_rgb_image