[图片：enter image description here] 我先检测图像的显著性，然后用 GrabCut 算法对显著性目标进行分割。但是结果只得到了显著性图，没有得到分割后的显著性目标图。报错如下：`error: (-5) image must have CV_8UC3 type in function grabCut`。下面是我的源代码，我该怎么办？
import tensorflow as tf
import numpy as np
import os
from scipy import misc
import argparse
import sys,cv2
from skimage.io import imread, imsave
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Per-channel mean subtracted from every network input. It is presumably in the
# channel order misc.imread returns (RGB) -- confirm against the training code.
# The (1, 1, 3) shape lets it broadcast over an HxWx3 image.
g_mean = np.reshape(np.array([126.88, 120.24, 112.19]), (1, 1, 3))
# Directory that receives every image written by main().
output_folder = "./test_output"
def rgba2rgb(img, max_dim=None):
    """Return *img* as an HxWx3 image.

    A 2-D (grayscale) or single-channel image is replicated into three
    channels, and a 4-channel image has its last (alpha) channel dropped.
    Other images are returned unchanged.

    Args:
        img: image array (H x W, H x W x 1, H x W x 3 or H x W x 4).
        max_dim: optional size limit. When given and the longer side of the
            image exceeds it, the image is downscaled with bicubic
            interpolation (dtype preserved) so that the longer side is about
            *max_dim*. The earlier version read this from a global ``args``
            that was never defined, so every call raised NameError.

    Returns:
        The three-channel image.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        # Replicate the gray channel so downstream code always sees 3 channels.
        img = np.stack((img, img, img), axis=-1)
    elif img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]

    if max_dim is not None:
        upper_dim = max(img.shape[:2])
        if upper_dim > max_dim:
            scale = max_dim / float(upper_dim)
            # cv2.resize takes (width, height).
            new_size = (max(1, int(round(img.shape[1] * scale))),
                        max(1, int(round(img.shape[0] * scale))))
            img = cv2.resize(img, new_size, interpolation=cv2.INTER_CUBIC)
    return img
def largest_contours_rect(saliency, thresh=0):
    """Return the bounding rectangle (x, y, w, h) of the largest salient region.

    Args:
        saliency: 2-D saliency map; pixels with ``saliency > thresh`` count
            as foreground. The array is not modified.
        thresh: binarization threshold (default 0, i.e. any non-zero pixel).

    Returns:
        ``cv2.boundingRect`` of the contour with the largest area.

    Raises:
        ValueError: if no pixel exceeds *thresh*, so there is no contour.
    """
    # findContours needs an 8-bit single-channel binary image; building it
    # explicitly also works for float saliency maps.
    binary = (np.asarray(saliency) > thresh).astype(np.uint8)
    # OpenCV 3.x returns (image, contours, hierarchy), while 2.x/4.x return
    # (contours, hierarchy). Index -2 is the contour list in both forms.
    contours = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        raise ValueError("saliency map has no pixel above %r; no contour to bound" % (thresh,))
    return cv2.boundingRect(max(contours, key=cv2.contourArea))
def refine_saliency_with_grabcut(img, saliency, iterations=1, thresh=0):
    """Segment the salient object in *img* with GrabCut seeded by *saliency*.

    Args:
        img: image to segment. cv2.grabCut only accepts 8-bit 3-channel
            (CV_8UC3) data, so 2-D or single-channel images are expanded to
            three channels, a fourth (alpha) channel is dropped, and
            non-uint8 data is clipped to [0, 255] and cast to uint8. Float
            input is therefore assumed to be on the 0-255 scale.
        saliency: 2-D saliency map; pixels ``> thresh`` are treated as salient.
            It is resized (nearest neighbour) to the image size if needed
            and is never modified in place.
        iterations: number of GrabCut iterations (default 1, as before).
        thresh: saliency threshold for the foreground seed.

    Returns:
        uint8 array with the image's height and width: 1 for (probable)
        foreground, 0 for (probable) background.

    Raises:
        ValueError: if *img* cannot be turned into an HxWx3 image (e.g. a
            4-D network batch was passed), or if the seed leaves no
            background pixel for GrabCut's background model.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img = np.stack((img, img, img), axis=-1)
    elif img.ndim == 3 and img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError("grabCut needs an HxWx3 image, got shape %r" % (img.shape,))
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    img = np.ascontiguousarray(img)

    height, width = img.shape[:2]
    sal = np.asarray(saliency)
    if sal.ndim == 3 and sal.shape[2] == 1:
        sal = sal[:, :, 0]
    if sal.shape[:2] != (height, width):
        # The mask must match the image pixel-for-pixel; cv2.resize takes (w, h).
        sal = cv2.resize(sal.astype(np.float32), (width, height),
                         interpolation=cv2.INTER_NEAREST)
    salient = sal > thresh

    rect = largest_contours_rect(salient.astype(np.uint8))
    x, y, w, h = rect

    # Seed labels on a fresh array. Outside the rectangle is definite
    # background. Inside it, salient pixels are probable foreground and the
    # rest probable background, so GrabCut can still refine both. With
    # GC_INIT_WITH_RECT these labels would be overwritten, so use
    # GC_INIT_WITH_MASK.
    mask = np.full((height, width), cv2.GC_BGD, np.uint8)
    roi = (slice(y, y + h), slice(x, x + w))
    mask[roi] = np.where(salient[roi], cv2.GC_PR_FGD, cv2.GC_PR_BGD)
    if not np.any((mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD)):
        raise ValueError("every pixel is salient; GrabCut needs some background seed")

    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    cv2.grabCut(img, mask, rect, bgd_model, fgd_model, iterations, cv2.GC_INIT_WITH_MASK)
    return np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 1, 0).astype(np.uint8)
def backprojection_saliency(img, args, size=(320, 232)):
    """Predict a saliency map via main(args) and refine it with GrabCut on *img*.

    Args:
        img: image to segment. Presumably the same picture that ``args.rgb``
            points to -- confirm with the caller.
        args: parsed command-line namespace, forwarded to main().
        size: (width, height) that *img* is resized to before segmentation.
            The default keeps the previously hard-coded (320, 232); pass None
            to keep *img* at its own size.

    Returns:
        The 0/1 uint8 mask from refine_saliency_with_grabcut, at the height
        and width of the (possibly resized) image.

    Raises:
        ValueError: if main(args) returned no saliency map.
    """
    saliency = main(args)
    if saliency is None:
        raise ValueError("main(args) produced no saliency map (empty --rgb_folder?)")
    if size is not None:
        img = cv2.resize(img, tuple(size))
    height, width = img.shape[:2]
    # main() returns the map at the input image's original size. GrabCut's
    # mask must match the (resized) image exactly, so bring it to that size.
    if saliency.shape[:2] != (height, width):
        saliency = cv2.resize(saliency, (width, height), interpolation=cv2.INTER_NEAREST)
    return refine_saliency_with_grabcut(img, saliency)
def _load_rgb(path):
    """Read the image at *path* and return it as a contiguous HxWx3 uint8 array."""
    rgb = misc.imread(path)
    if rgb.ndim == 2 or rgb.shape[2] == 4:
        rgb = rgba2rgb(rgb)
    # cv2.grabCut only accepts 8-bit, 3-channel (CV_8UC3) contiguous data.
    return np.ascontiguousarray(rgb.astype(np.uint8))


def _predict_alpha(sess, image_batch, pred_mattes, rgb):
    """Run the saliency network on one HxWx3 uint8 image.

    Returns the prediction resized with misc.imresize back to the image's
    height and width.
    """
    origin_shape = rgb.shape[:2]
    # The graph is fed a batch of one 320x320 image with g_mean subtracted.
    # This float tensor is for the network only; it is NOT a valid grabCut input.
    net_input = misc.imresize(rgb, [320, 320, 3], interp="nearest").astype(np.float32) - g_mean
    feed_dict = {image_batch: np.expand_dims(net_input, 0)}
    pred_alpha = sess.run(pred_mattes, feed_dict=feed_dict)
    return misc.imresize(np.squeeze(pred_alpha), origin_shape)


def main(args):
    """Predict saliency maps with the restored TF model and segment a single image with GrabCut.

    With ``args.rgb_folder`` set, every file in that directory is processed
    and its saliency map is written to ``output_folder`` under the same file
    name. Otherwise ``args.rgb`` is processed and three files are written:
    the saliency map to ``alpha.png``, the GrabCut mask (0/255) to
    ``segmentation.png``, and the image with background pixels zeroed to
    ``segmented.png``.

    NOTE(review): scipy.misc.imread/imresize/imsave only exist in old SciPy
    releases; newer SciPy versions removed them.

    Args:
        args: namespace from parse_arguments (rgb, rgb_folder, gpu_fraction).

    Returns:
        The last saliency map computed, or None when rgb_folder is empty.

    Raises:
        ValueError: if neither rgb nor rgb_folder is set.
    """
    if not args.rgb_folder and not args.rgb:
        raise ValueError("either --rgb or --rgb_folder must be given")
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    final_alpha = None
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./salience_model'))
        image_batch = tf.get_collection('image_batch')[0]
        pred_mattes = tf.get_collection('mask')[0]

        if args.rgb_folder:
            for rgb_pth in os.listdir(args.rgb_folder):
                rgb = _load_rgb(os.path.join(args.rgb_folder, rgb_pth))
                final_alpha = _predict_alpha(sess, image_batch, pred_mattes, rgb)
                misc.imsave(os.path.join(output_folder, rgb_pth), final_alpha)
        else:
            rgb = _load_rgb(args.rgb)
            final_alpha = _predict_alpha(sess, image_batch, pred_mattes, rgb)
            misc.imsave(os.path.join(output_folder, 'alpha.png'), final_alpha)
            # Pass the original uint8 image (not the mean-subtracted float
            # batch) to GrabCut, and a copy of the map so the returned
            # saliency is not overwritten.
            mask = refine_saliency_with_grabcut(rgb, final_alpha.copy())
            # A uint8 0/1 mask is saved unscaled and looks black, so stretch it.
            misc.imsave(os.path.join(output_folder, 'segmentation.png'), mask * 255)
            misc.imsave(os.path.join(output_folder, 'segmented.png'),
                        rgb * mask[:, :, np.newaxis])
    return final_alpha
def parse_arguments(argv):
    """Parse the command-line options of the saliency/GrabCut script.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with ``rgb``, ``rgb_folder`` and ``gpu_fraction``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--rgb', type=str, default=None,
                        help='path of a single input rgb image')
    # main() checks rgb_folder first, so it wins when both are given.
    parser.add_argument('--rgb_folder', type=str, default=None,
                        help='directory of input rgb images; takes precedence over --rgb')
    parser.add_argument('--gpu_fraction', type=float, default=1.0,
                        help='fraction of GPU memory to reserve; about 4G is usually enough')
    return parser.parse_args(argv)
# Script entry point. The stray "``" markdown residue after the call was a
# SyntaxError that stopped the whole module from importing.
if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))
最佳答案
我先对图像做阈值化得到二值图，在其中找出最大轮廓，再据此创建一个掩模，然后用这个掩模进行 GrabCut 分割。
源代码：
结果是这样的:
#!/usr/bin/python3
# 2017.11.27 15:26:53 CST
# 2017.11.27 16:37:38 CST
import numpy as np
import cv2
## Read the image. cv2.imread returns None instead of raising when it fails.
img = cv2.imread("tt04_flower.png")
if img is None:
    raise SystemExit("cannot read tt04_flower.png")
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

## Binarize with Otsu's method. The 100 is ignored when THRESH_OTSU is set.
## This assumes the object is brighter than the background; use
## THRESH_BINARY_INV otherwise.
th, threshed = cv2.threshold(gray, 100, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

## Find contours. OpenCV 3.x returns (image, contours, hierarchy) and
## 2.x/4.x return (contours, hierarchy); index -2 is the contour list in both.
cnts = cv2.findContours(threshed, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
if not cnts:
    raise SystemExit("no foreground region found after thresholding")

## Take the contour with the largest area and its bounding rectangle.
cnt = max(cnts, key=cv2.contourArea)
bbox = cv2.boundingRect(cnt)

## Seed mask: everything is probable background, and the filled largest
## contour is definite foreground.
mask = np.full(gray.shape, cv2.GC_PR_BGD, np.uint8)
cv2.drawContours(mask, [cnt], -1, cv2.GC_FGD, -1)

## Segment with GrabCut (5 iterations). With GC_INIT_WITH_MASK the rect
## argument is not used for initialization.
bgdModel = np.zeros((1, 65), np.float64)
fgdModel = np.zeros((1, 65), np.float64)
cv2.grabCut(img, mask, bbox, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_MASK)

## Keep definite/probable foreground pixels and zero out the rest.
mask2 = np.where((mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD), 0, 1).astype("uint8")
grabcut = img * mask2[:, :, np.newaxis]

## Save and display.
cv2.imwrite("flower_res.png", grabcut)
cv2.imshow("(1) source", img)
cv2.imshow("(2) grabcut", grabcut)
cv2.waitKey()
cv2.destroyAllWindows()