我已经使用tensorflow的LinearClassifier()类训练了逻辑回归模型模型,并设置了model_dir参数,该参数指定在模型训练过程中保存检查点的元数据的位置:
# Create temporary directory where metagraphs will evenually be saved
model_dir = tempfile.mkdtemp()
logistic_model = tf.contrib.learn.LinearClassifier(
feature_columns=feature_columns,
n_classes=num_labels, model_dir=model_dir)
我一直在阅读有关从图元还原模型的信息,但是对于使用高级api创建的模型如何还原却一无所获。 LinearClassifier()具有predict()函数,但是我找不到有关如何使用通过检查点元图恢复的模型实例运行预测的任何文档。我将如何去做呢?恢复模型后,我的理解是我正在使用tf.Sess对象,该对象缺少LinearClassifier类的所有内置功能,如下所示:
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# Run prediction algorithm...
如何运行高级api所使用的预测算法,以对恢复的模型进行预测?有没有更好的方法来解决这个问题?
感谢您的输入。
最佳答案
LinearClassifier()
具有'model_dir'参数,如果指向已训练模型的时间将恢复该模型。
在培训期间,您需要执行以下操作:
logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir)
classifier.fit(X_train, y_train, steps=10)
在推论过程中,
LinearClassifier()
将从给定的路径加载经过训练的模型,您无需使用fit()
方法,而是调用predict()
方法:logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir)
y_pred = classifier.predict(X_test)
关于python - 如何从Tensorflow高级API恢复经过训练的LinearClassifier并进行预测,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/44764887/