问题描述
我正在使用 Tensorflow 1.4.
I am working with Tensorflow 1.4.
我创建了一个自定义的 tf.estimator 来进行分类,如下所示:
I created a custom tf.estimator in order to do classification, like this:
def model_fn():
# Some operations here
[...]
return tf.estimator.EstimatorSpec(mode=mode,
predictions={"Preds": predictions},
loss=cost,
train_op=loss,
eval_metric_ops=eval_metric_ops,
training_hooks=[summary_hook])
my_estimator = tf.estimator.Estimator(model_fn=model_fn,
params=model_params,
model_dir='/my/directory')
我可以轻松训练它:
input_fn = create_train_input_fn(path=train_files)
my_estimator.train(input_fn=input_fn)
其中input_fn是一个从tfrecords文件读取数据的函数,使用tf.data.Dataset API.
where input_fn is a function that reads data from tfrecords files, with the tf.data.Dataset API.
当我从 tfrecords 文件中读取数据时,我在进行预测时在内存中没有标签.
As I am reading from tfrecords files, I don't have labels in memory when I am making predictions.
我的问题是,如何通过 predict() 方法或 evaluate() 方法返回预测和标签?
My question is, how can I have predictions AND labels returned, either by the predict() method or the evaluate() method?
似乎没有办法两者兼得.predict() 无法访问 (?) 标签,并且无法使用 evaluate() 方法访问 predictions 字典.
It seems there is no way to have both. predict() does not have access (?) to labels, and it is not possible to access the predictions dictionary with the evaluate() method.
推荐答案
完成培训后,在 '/my/directory'
中,您有一堆检查点文件.
After you finished your training, in '/my/directory'
you have a bunch of checkpoint files.
您需要再次设置输入管道,手动加载其中一个文件,然后开始循环存储预测和标签的批次:
You need to set up your input pipeline again, manually load one of those files, then start looping through your batches storing the predictions and the labels:
# Rebuild the input pipeline
input_fn = create_eval_input_fn(path=eval_files)
features, labels = input_fn()
# Rebuild the model
predictions = model_fn(features, labels, tf.estimator.ModeKeys.EVAL).predictions
# Manually load the latest checkpoint
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('/my/directory')
saver.restore(sess, ckpt.model_checkpoint_path)
# Loop through the batches and store predictions and labels
prediction_values = []
label_values = []
while True:
try:
preds, lbls = sess.run([predictions, labels])
prediction_values += preds
label_values += lbls
except tf.errors.OutOfRangeError:
break
# store prediction_values and label_values somewhere
更新:改为直接使用您已有的model_fn
函数.
Update: changed to use directly the model_fn
function you already have.
这篇关于如何使用 tf.estimator(使用 predict 或 eval 方法)返回预测和标签?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!