mmdetection的mask输出
推理
利用 `./demo/image_demo.py` 进行推理，命令如下（第二个位置参数为模型名：可以使用别名 `rtmdet-ins-s`，也可以使用完整的配置名 `rtmdet-ins_s_8xb32-300e_coco`，二选一即可）：
- python demo/image_demo.py demo/demo.jpg rtmdet-ins-s --show
- python demo/image_demo.py demo/demo.jpg rtmdet-ins_s_8xb32-300e_coco --show
输出的结果:
- ./outputs/preds/demo.json 文件：记录了 bbox、label、score，以及经过 RLE 编码（而非加密）的 mask（`counts` 字段），可用 `pycocotools.mask.decode` 还原为二值图
- ./outputs/vis/demo.jpg文件:推理出的可视化图
输出推理出的masks
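修改 `./mmdet/apis/det_inferencer.py` 文件：添加 `Save_masks` 函数（作为推理器类的方法，与 `pred2dict` 同级缩进），并修改 `pred2dict` 函数。注意：下面的代码片段在排版时丢失了缩进，粘贴时需按 Python 语法恢复；若代码中使用了 `cv2` 而该文件顶部尚未 `import cv2`，需要自行补充。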
def Save_masks(self, masks, pred_out_dir, prefix=''):
    """Write each predicted instance mask to disk as a binary PNG image.

    Files go to ``<pred_out_dir>/preds/masks/<prefix>mask_<i>.png``, where
    ``i`` is the instance index. This is the same ``preds`` sub-directory
    that ``pred2dict`` uses for its JSON output. Foreground pixels are
    stored as 255 and background pixels as 0.

    Args:
        masks: Iterable of per-instance masks. Each element may be a tensor
            exposing ``.cpu()`` (as returned by
            ``data_sample.pred_instances.get('masks')``) or anything
            ``np.asarray`` accepts. Presumably boolean (H, W) masks; TODO
            confirm for models that emit soft masks.
        pred_out_dir (str): Root output directory. If empty, nothing is
            written. This matches ``pred2dict``'s "empty means don't save"
            contract.
        prefix (str): Optional file-name prefix (e.g. the image stem), so
            that masks from different images do not overwrite each other.
            Defaults to '', which keeps the original ``mask_<i>.png`` names.

    Returns:
        list[str]: Paths of the files that were written.
    """
    if not pred_out_dir:
        return []

    mask_dir = osp.join(pred_out_dir, 'preds', 'masks')
    written = []
    for i, mask in enumerate(masks):
        # Tensors must be moved to host memory before numpy conversion.
        if hasattr(mask, 'cpu'):
            mask = mask.cpu().numpy()
        mask_img = np.asarray(mask).astype(np.uint8) * 255
        out_path = osp.join(mask_dir, f'{prefix}mask_{i}.png')
        # mmcv.imwrite (already used elsewhere in this file) creates missing
        # parent directories. A bare cv2.imwrite silently returns False in
        # that case, and it would also need a cv2 import this file may lack.
        mmcv.imwrite(mask_img, out_path)
        written.append(out_path)
    return written
# TODO: The data format and fields saved in json need further discussion.
# Maybe should include model name, timestamp, filename, image info etc.
def pred2dict(self,
              data_sample: DetDataSample,
              pred_out_dir: str = '') -> Dict:
    """Extract elements necessary to represent a prediction into a
    dictionary.

    It's better to contain only basic data elements such as strings and
    numbers in order to guarantee it's json-serializable.

    When ``pred_out_dir`` is non-empty, two extra things happen:

    * the result dict is dumped to ``<pred_out_dir>/preds/<name>.json``;
    * if the prediction carries instance masks, each mask is also exported
      as a PNG through ``self.Save_masks``.

    Args:
        data_sample (:obj:`DetDataSample`): Predictions of the model.
        pred_out_dir: Dir to save the inference results w/o
            visualization. If left as empty, no file will be saved.
            Defaults to ''.

    Returns:
        dict: Prediction results. For instance predictions it holds
        ``bboxes``, ``labels`` and ``scores``, plus ``masks`` (encoded
        masks whose ``counts`` are decoded to ``str``) when masks exist.
        For panoptic predictions it holds ``panoptic_seg_path`` (when
        saving) or the ``panoptic_seg`` array (when not saving).

    Raises:
        RuntimeError: If the sample has a panoptic prediction but
            panopticapi is not installed.
    """
    is_save_pred = bool(pred_out_dir)
    if is_save_pred and 'img_path' in data_sample:
        img_path = osp.basename(data_sample.img_path)
        img_path = osp.splitext(img_path)[0]
        out_img_path = osp.join(pred_out_dir, 'preds',
                                img_path + '_panoptic_seg.png')
        out_json_path = osp.join(pred_out_dir, 'preds', img_path + '.json')
    elif is_save_pred:
        # No source image path: name outputs after a running counter.
        out_img_path = osp.join(
            pred_out_dir, 'preds',
            f'{self.num_predicted_imgs}_panoptic_seg.png')
        out_json_path = osp.join(pred_out_dir, 'preds',
                                 f'{self.num_predicted_imgs}.json')
        self.num_predicted_imgs += 1

    result = {}
    if 'pred_instances' in data_sample:
        # Grab the masks before the .numpy() conversion below.
        # mask2bbox and Save_masks call .cpu() on them.
        masks = data_sample.pred_instances.get('masks')
        pred_instances = data_sample.pred_instances.numpy()
        result = {
            'bboxes': pred_instances.bboxes.tolist(),
            'labels': pred_instances.labels.tolist(),
            'scores': pred_instances.scores.tolist()
        }
        if masks is not None:
            # Only export mask PNGs when saving was requested. An empty
            # pred_out_dir means "save nothing".
            if is_save_pred:
                self.Save_masks(masks, pred_out_dir)
            if pred_instances.bboxes.sum() == 0:
                # Fake bbox, such as the SOLO.
                bboxes = mask2bbox(masks.cpu()).numpy().tolist()
                result['bboxes'] = bboxes
            encode_masks = encode_mask_results(pred_instances.masks)
            for encode_mask in encode_masks:
                # bytes are not JSON-serializable; store counts as str.
                if isinstance(encode_mask['counts'], bytes):
                    encode_mask['counts'] = encode_mask['counts'].decode()
            result['masks'] = encode_masks

    if 'pred_panoptic_seg' in data_sample:
        if VOID is None:
            raise RuntimeError(
                'panopticapi is not installed, please install it by: '
                'pip install git+https://github.com/cocodataset/'
                'panopticapi.git.')
        pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0]
        # Pixels whose class id equals len(classes) are mapped to VOID.
        pan[pan % INSTANCE_OFFSET == len(
            self.model.dataset_meta['classes'])] = VOID
        pan = id2rgb(pan).astype(np.uint8)
        if is_save_pred:
            mmcv.imwrite(pan[:, :, ::-1], out_img_path)
            result['panoptic_seg_path'] = out_img_path
        else:
            result['panoptic_seg'] = pan

    if is_save_pred:
        mmengine.dump(result, out_json_path)

    return result