本文介绍了如何使用tf.estimator.Estimator从检查点进行预测?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我刚刚训练了CNN以识别具有张量流的黑子.我的模型与几乎相同.问题是我无法在任何地方找到关于如何使用训练阶段生成的检查点进行预测的清晰解释.

I just trained a CNN to recognise sunspots with tensorflow. My model is pretty much the same as this. The problem is that I cannot find anywhere a clear explanation on how to make predictions with the checkpoint generated by the training phase.

尝试使用标准的还原方法:

Tried using the standard restore method:

saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
saver.restore(sess,'./model/model.ckpt')

但是我不知道如何运行它.
像这样使用tf.estimator.Estimator.predict()进行了尝试:

but then I cannot figure out how to run it.
Tried using tf.estimator.Estimator.predict() like this:

# Create the Estimator (should reload the last checkpoint but it doesn't)
sunspot_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="./model")

# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=50)

# predict with the model and print results
pred_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": pred_data},
    shuffle=False)
pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)
print(pred_results)

,但它的作用是吐出<generator object Estimator.predict at 0x10dda6bf8>.如果我使用相同的代码但使用tf.estimator.Estimator.evaluate(),则它的工作原理就像一个魅力(重新加载模型,执行评估并将其发送到TensorBoard).

but what it does is spitting out <generator object Estimator.predict at 0x10dda6bf8>. While if I use the same code but with tf.estimator.Estimator.evaluate() it works like a charm (reloads the model, performs evaluation and sends it to TensorBoard).

我知道有很多类似的问题,但是我真的找不到适合我的方法.

I know there are many similar questions but I couldn't really find the way that worked for me.

推荐答案

sunspot_classifier.predict(input_fn=pred_input_fn)返回生成器.因此pred_results是生成器对象.要从中获取价值,您需要通过next(pred_results)

sunspot_classifier.predict(input_fn=pred_input_fn) returns generator. So pred_results is generator object. To get value from it you need to iterate it by next(pred_results)

解决方法是print(next(pred_results))

这篇关于如何使用tf.estimator.Estimator从检查点进行预测?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-12 02:47