我首先检测到图像的显著性,然后使用 GrabCut 算法对显著性目标进行分割。但是,结果只得到了显著性图像,并没有分割出显著图。错误如下:error: (-5) image must have CV_8UC3 type in function 'grabCut'。这是我的源代码,我该怎么办?

    import tensorflow as tf
    import numpy as np
    import os
    from scipy import misc
    import argparse
    import sys,cv2
    from skimage.io import imread, imsave
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg

g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])
output_folder = "./test_output"

def rgba2rgb(img, max_dim=None):
    """Normalize *img* to an H x W x 3 RGB array, optionally downscaling it.

    Grayscale input is promoted to 3 channels; RGBA input keeps only RGB.
    If the larger image side exceeds *max_dim*, the image is rescaled so it
    fits.  *max_dim* defaults to the global ``args.max_dim`` to stay
    backward-compatible with the original one-argument call.

    NOTE(review): ``gray2rgb`` and ``rescale`` (presumably from skimage) and
    the global ``args`` are not defined/imported in this file — confirm
    against the original project.
    """
    if img.ndim == 2:
        # Single-channel image: replicate to 3 channels.
        img = gray2rgb(img)
    elif img.shape[2] == 4:
        # RGBA: drop the alpha channel.
        img = img[:, :, :3]
    limit = args.max_dim if max_dim is None else max_dim
    upper_dim = max(img.shape[:2])
    if upper_dim > limit:
        img = rescale(img, limit / float(upper_dim), order=3)
    return img

def largest_contours_rect(saliency):
    """Return the (x, y, w, h) bounding rect of the largest contour in *saliency*.

    *saliency* is a 2-D saliency map; any non-zero pixel counts as foreground.
    """
    # findContours requires a contiguous single-channel 8-bit image; the
    # original passed whatever dtype the saliency map had.  The *3 scaling is
    # kept from the original (findContours only distinguishes zero/non-zero).
    binary = np.ascontiguousarray(saliency * 3).astype(np.uint8)
    # [-2] picks the contour list on both the 2-tuple (OpenCV 2/4) and
    # 3-tuple (OpenCV 3) return signatures.
    contours = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    contours = sorted(contours, key=cv2.contourArea)
    return cv2.boundingRect(contours[-1])

def refine_saliency_with_grabcut(img, saliency):
    """Refine a saliency map into a binary segmentation mask using GrabCut.

    img      -- H x W x 3 image.  grabCut strictly requires 8-bit 3-channel
                (CV_8UC3) input — passing anything else raises
                "error: (-5) image must have CV_8UC3 type", so the image is
                converted here.
    saliency -- H x W saliency map, same spatial size as *img*; non-zero
                pixels are treated as sure foreground.

    Returns an H x W uint8 mask: 1 = foreground, 0 = background.
    """
    # grabCut is strict about dtypes: CV_8UC3 image, CV_8UC1 mask.
    img = np.ascontiguousarray(img).astype(np.uint8)
    rect = largest_contours_rect(saliency)
    bgdmodel = np.zeros((1, 65), np.float64)
    fgdmodel = np.zeros((1, 65), np.float64)
    # Build the seed mask without mutating the caller's saliency array:
    # probable background everywhere, sure foreground where salient.
    mask = np.full(saliency.shape[:2], cv2.GC_PR_BGD, np.uint8)
    mask[saliency > 0] = cv2.GC_FGD
    # The mask carries the seeding, so initialise from the mask;
    # GC_INIT_WITH_RECT (as in the original) would discard it entirely.
    cv2.grabCut(img, mask, rect, bgdmodel, fgdmodel, 1, cv2.GC_INIT_WITH_MASK)
    return np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')

def backprojection_saliency(img, args):
    """Compute the saliency map for the image named by *args* and refine it
    into a binary GrabCut mask.

    The original pasted code mixed 4- and 8-space indentation, which is an
    IndentationError; the logic itself is unchanged.
    """
    # main() runs the TF saliency model and returns the saliency map.
    saliency = main(args)
    # Resize the image to the model's working resolution.
    # NOTE(review): 320x232 must match the saliency map size — confirm.
    img = cv2.resize(img, (320, 232))
    mask = refine_saliency_with_grabcut(img, saliency)
    return mask

def main(args):
    """Run the saliency model over one image (``--rgb``) or a folder
    (``--rgb_folder``), writing saliency maps into *output_folder*.

    For a single image, also refines the saliency map with GrabCut and
    saves the segmentation.  Returns the last saliency map produced.

    Fixes vs. the pasted original:
    - indentation repaired (the ``with`` block and the trailing statements
      were at inconsistent levels — a SyntaxError);
    - GrabCut is given the ORIGINAL 8-bit image, not the mean-subtracted
      4-D float batch — that batch is what triggered
      "error: (-5) image must have CV_8UC3 type".
    """
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./salience_model'))
        image_batch = tf.get_collection('image_batch')[0]
        pred_mattes = tf.get_collection('mask')[0]

        if args.rgb_folder:
            rgb_pths = os.listdir(args.rgb_folder)
            for rgb_pth in rgb_pths:
                rgb = misc.imread(os.path.join(args.rgb_folder, rgb_pth))
                if rgb.shape[2] == 4:
                    rgb = rgba2rgb(rgb)
                origin_shape = rgb.shape
                # Resize to the 320x320 model input and subtract the mean.
                rgb = np.expand_dims(
                    misc.imresize(rgb.astype(np.uint8), [320, 320, 3],
                                  interp="nearest").astype(np.float32) - g_mean, 0)
                feed_dict = {image_batch: rgb}
                pred_alpha = sess.run(pred_mattes, feed_dict=feed_dict)
                final_alpha = misc.imresize(np.squeeze(pred_alpha), origin_shape)
                misc.imsave(os.path.join(output_folder, rgb_pth), final_alpha)
        else:
            # Keep the untouched 8-bit image around: GrabCut needs it later.
            raw = misc.imread(args.rgb)
            if raw.shape[2] == 4:
                raw = rgba2rgb(raw)
            origin_shape = raw.shape[:2]
            rgb = np.expand_dims(
                misc.imresize(raw.astype(np.uint8), [320, 320, 3],
                              interp="nearest").astype(np.float32) - g_mean, 0)
            feed_dict = {image_batch: rgb}
            pred_alpha = sess.run(pred_mattes, feed_dict=feed_dict)
            final_alpha = misc.imresize(np.squeeze(pred_alpha), origin_shape)
            misc.imsave(os.path.join(output_folder, 'alpha.png'), final_alpha)
            # BUG FIX: pass the original image (CV_8UC3), not the float batch.
            result = refine_saliency_with_grabcut(raw, final_alpha)
            misc.imsave(os.path.join(output_folder, 'segmentation.png'), result)
    return final_alpha

def parse_arguments(argv):
    """Parse this script's command-line options from *argv*.

    Options:
        --rgb           path to a single input image (default: None)
        --rgb_folder    path to a folder of input images (default: None)
        --gpu_fraction  fraction of GPU memory to allocate (default: 1.0)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--rgb', type=str, default=None, help='input rgb')
    parser.add_argument('--rgb_folder', type=str, default=None,
                        help='input rgb')
    parser.add_argument('--gpu_fraction', type=float, default=1.0,
                        help='how much gpu is needed, usually 4G is enough')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Fixed: the pasted original had stray trailing backticks (``) after this
    # call — a SyntaxError — and an over-indented body.
    main(parse_arguments(sys.argv[1:]))

最佳答案

我使用阈值化后的二值图像找到最大轮廓,然后据此创建一个掩模,并用该掩模初始化 GrabCut。

来源:

(示意图:使用 GrabCut 算法分离显著区域 — 输入图像)

结果是这样的:

(示意图:使用 GrabCut 算法分离显著区域 — 分割结果)

#!/usr/bin/python3
# 2017.11.27 15:26:53 CST
# 2017.11.27 16:37:38 CST

import numpy as np
import cv2

## Load the source image and build a grayscale copy for thresholding.
img = cv2.imread("tt04_flower.png")
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

## Binarize with Otsu's automatic threshold.
retval, binary = cv2.threshold(gray, 100, 255,
                               cv2.THRESH_BINARY | cv2.THRESH_OTSU)

## Extract contours; [-2] keeps this working across OpenCV return signatures.
contours = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

## Pick the contour with the largest area.
largest = max(contours, key=cv2.contourArea)

## Its bounding rectangle (also unpacked for convenience).
bbox = x, y, w, h = cv2.boundingRect(largest)

## Seed mask: probable background everywhere, sure foreground inside the contour.
mask = np.ones_like(gray, np.uint8) * cv2.GC_PR_BGD
cv2.drawContours(mask, [largest], -1, cv2.GC_FGD, -1)

## Run GrabCut for 5 iterations, initialised from the seed mask.
bgdModel = np.zeros((1, 65), np.float64)
fgdModel = np.zeros((1, 65), np.float64)
cv2.grabCut(img, mask, bbox, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_MASK)

## Collapse GrabCut labels (BGD / PR_BGD -> 0, rest -> 1) and mask the image.
keep = np.where((mask == 2) | (mask == 0), 0, 1).astype("uint8")
grabcut = img * keep[:, :, np.newaxis]

## Save and display the result.
cv2.imwrite("flower_res.png", grabcut)
cv2.imshow("(1) source", img)
cv2.imshow("(2) grabcut", grabcut)
cv2.waitKey()

10-08 11:41