我最近正在基于带有服务器接口的tensorflow CNN,MNIST数据集制作一个项目。

在预测部分,我使用tf.argmax()获得最大的logit,它将是预测值。但是,它返回的值似乎不是正确的答案。

预测函数大约是这样的:

    self.img = tf.reshape(tf.image.convert_image_dtype(img, tf.float32), shape=[1, 28, 28, 1])
    self._create_model()

    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state('../checkpoints/')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

    pred = tf.nn.softmax(self.logits)
    prediction = tf.argmax(pred, 1)
    logit = sess.run(pred)
    result = sess.run(prediction)[0]
    print(logit)
    print(result)

    return result


结果是:

127.0.0.1 - - [19/Apr/2018 21:35:47] "POST /index.html HTTP/1.1" 200 -
[[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]]
1


如您所见,logit显示最大索引为5,但是tf.argmax()却给了我1。

顺便说一下,my model是基本MNIST CNN模型,如您在链接中所见。

那么这个tf.argmax()函数发生了什么,或者我的代码有什么问题?

最佳答案

由于您的logitpred)和resultprediction[0])来自两个不同的sess.run,所以我想知道运行之间是否存在一些差异。例如,您在图中有一个迭代器,用于将输入发送到模型。通过不同的运行,迭代器将发送不同的数据,从而导致不同的预测。有趣的是,如果将predprediction放在一个相同的sess.run中,像这样:

logit, result = sess.run((pred, prediction))
print(logit)
print(result[0])

关于python - tf.argmax()返回意外结果,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49923675/

10-12 23:05