mmdetection的mask输出

推理

利用./demo/image_demo.py进行推理,可任选以下两条命令之一(第一条使用模型别名,第二条使用完整的配置名):

  1. python demo/image_demo.py demo/demo.jpg rtmdet-ins-s --show
  2. python demo/image_demo.py demo/demo.jpg rtmdet-ins_s_8xb32-300e_coco --show

输出的结果:

  1. ./outputs/preds/demo.json文件:记录了bbox、label、score和经过RLE编码(并非加密)的mask
  2. ./outputs/vis/demo.jpg文件:推理出的可视化图

输出推理出的masks

修改./mmdet/apis/det_inferencer.py文件,添加Save_masks函数,并修改pred2dict函数(在其中调用Save_masks)。Save_masks用到了cv2,若该文件尚未导入,需在文件开头添加import cv2。

    def Save_masks(self, masks, pred_out_dir, prefix=''):
        """Write every predicted instance mask to disk as a PNG image.

        Mask ``i`` is saved to
        ``<pred_out_dir>/preds/masks/<prefix>mask_<i>.png``. Each mask is
        cast to ``uint8`` and multiplied by 255 before writing, so a 0/1
        (or boolean) mask becomes a black/white image.

        Args:
            masks: Iterable of per-instance masks. Every item must support
                ``.cpu().numpy()`` (presumably a 2-D torch tensor of
                0/1 values -- TODO confirm against the model output).
            pred_out_dir (str): Root directory of the inference outputs.
                If empty, nothing is written.
            prefix (str): Optional string prepended to every file name,
                e.g. the image stem, so that masks of different images do
                not overwrite each other. Defaults to ''.
        """
        # An empty output dir means "do not save anything".
        if not pred_out_dir:
            return

        import os
        import warnings

        mask_dir = os.path.join(pred_out_dir, 'preds', 'masks')
        # cv2.imwrite does not create missing directories; it just returns
        # False, so make sure the target directory exists first.
        os.makedirs(mask_dir, exist_ok=True)

        for i, mask in enumerate(masks):
            mask_path = os.path.join(mask_dir, f'{prefix}mask_{i}.png')
            mask_img = mask.cpu().numpy().astype(np.uint8) * 255
            # Surface write failures instead of losing masks silently.
            if not cv2.imwrite(mask_path, mask_img):
                warnings.warn(f'Failed to write mask to {mask_path}')
        

    # TODO: The data format and fields saved in json need further discussion.
    #  Maybe should include model name, timestamp, filename, image info etc.
    def pred2dict(self,
                  data_sample: DetDataSample,
                  pred_out_dir: str = '') -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        It's better to contain only basic data elements such as strings and
        numbers in order to guarantee it's json-serializable.

        When ``pred_out_dir`` is non-empty, the result dict is dumped to
        ``<pred_out_dir>/preds/<name>.json``, instance masks (if any) are
        passed to ``self.Save_masks`` and the panoptic map (if any) is
        written to ``<pred_out_dir>/preds/<name>_panoptic_seg.png``.
        ``<name>`` is the image file stem, or a running counter
        (``self.num_predicted_imgs``) when the sample has no ``img_path``.

        Args:
            data_sample (:obj:`DetDataSample`): Predictions of the model.
            pred_out_dir: Dir to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Prediction results. May contain ``bboxes``, ``labels``,
            ``scores`` and ``masks`` for instance predictions, and either
            ``panoptic_seg_path`` (when saving) or ``panoptic_seg`` (when
            not saving) for panoptic predictions.

        Raises:
            RuntimeError: If the sample has a panoptic prediction but
                ``VOID`` is None (panopticapi is not installed).
        """
        is_save_pred = pred_out_dir != ''

        if is_save_pred and 'img_path' in data_sample:
            # Name the outputs after the input image (without extension).
            img_path = osp.basename(data_sample.img_path)
            img_path = osp.splitext(img_path)[0]
            out_img_path = osp.join(pred_out_dir, 'preds',
                                    img_path + '_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds', img_path + '.json')
        elif is_save_pred:
            # No image path available: fall back to a running counter.
            out_img_path = osp.join(
                pred_out_dir, 'preds',
                f'{self.num_predicted_imgs}_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds',
                                     f'{self.num_predicted_imgs}.json')
            self.num_predicted_imgs += 1

        result = {}
        if 'pred_instances' in data_sample:
            # Fetch the masks before the instances are converted to numpy;
            # Save_masks and mask2bbox below call ``.cpu()`` on them.
            masks = data_sample.pred_instances.get('masks')
            pred_instances = data_sample.pred_instances.numpy()
            result = {
                'bboxes': pred_instances.bboxes.tolist(),
                'labels': pred_instances.labels.tolist(),
                'scores': pred_instances.scores.tolist()
            }
            if masks is not None:
                # Only touch the disk when an output dir was requested.
                if is_save_pred:
                    self.Save_masks(masks, pred_out_dir)
                if pred_instances.bboxes.sum() == 0:
                    # Fake bbox, such as the SOLO.
                    bboxes = mask2bbox(masks.cpu()).numpy().tolist()
                    result['bboxes'] = bboxes
                encode_masks = encode_mask_results(pred_instances.masks)
                for encode_mask in encode_masks:
                    # bytes are not json-serializable; store them as str.
                    if isinstance(encode_mask['counts'], bytes):
                        encode_mask['counts'] = encode_mask['counts'].decode()
                result['masks'] = encode_masks

        if 'pred_panoptic_seg' in data_sample:
            if VOID is None:
                raise RuntimeError(
                    'panopticapi is not installed, please install it by: '
                    'pip install git+https://github.com/cocodataset/'
                    'panopticapi.git.')

            pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0]
            # Pixels whose category id equals the number of classes are
            # mapped to VOID (presumably the "no class" label -- confirm).
            pan[pan % INSTANCE_OFFSET == len(
                self.model.dataset_meta['classes'])] = VOID
            pan = id2rgb(pan).astype(np.uint8)

            if is_save_pred:
                # Channel order is reversed before writing (presumably
                # RGB -> BGR, as expected by mmcv.imwrite -- confirm).
                mmcv.imwrite(pan[:, :, ::-1], out_img_path)
                result['panoptic_seg_path'] = out_img_path
            else:
                result['panoptic_seg'] = pan

        if is_save_pred:
            mmengine.dump(result, out_json_path)

        return result
05-08 17:50