本文介绍了ValueError:Tensor Tensor(...)不是此图的元素.使用全局变量keras模型时的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用flask运行Web服务器,并且在尝试使用vgg16时出现错误,vgg16是keras的预训练VGG16模型的全局变量.我不知道为什么会出现此错误,或者它是否与Tensorflow后端有关.这是我的代码:

I'm running a web server using flask and the error comes up when I try to use vgg16, which is the global variable for keras' pre-trained VGG16 model. I have no idea why this error rises or whether it has anything to do with the Tensorflow backend.Here is my code:

vgg16 = VGG16(weights='imagenet', include_top=True)

def getVGG16Prediction(img_path):
    global vgg16

    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    pred = vgg16.predict(x)
    return x, sort(decode_predictions(pred, top=3)[0])

@app.route("/uploadMultipleImages", methods=["POST"])
def uploadMultipleImages():
    uploaded_files = request.files.getlist("file[]")
    for file in uploaded_files:
        path = os.path.join(STATIC_PATH, file.filename)
        pInput, result = getVGG16Prediction(path)

这是完整的错误:

任何评论或建议,我们将不胜感激.谢谢.

Any comment or suggestion is greatly appreciated. Thank you.

推荐答案

此github问题.在此处引用相关部分:

Take a look at avital's answer on this github issue. Quoting the relevant part here:

graph = tf.get_default_graph()

在另一个线程中(或者可能在异步事件处理程序中),执行以下操作:

In the other thread (or perhaps in an asynchronous event handler), do:

global graph
with graph.as_default():
    (... do inference here ...)

我对此做了一些修改,然后将图形存储在应用程序的配置对象中,而不是将其设置为全局对象.

I modified this a bit and stored the graph in my app's config object instead of making it a global.

get_default_graph TensorFlow文档解释了为什么这样做是必要的:

The TensorFlow documentation for get_default_graph explains why this is necessary:

这篇关于ValueError:Tensor Tensor(...)不是此图的元素.使用全局变量keras模型时的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

06-04 08:07