我有一些python代码可以使用Tensorflow的TFRecords和Dataset API训练网络。我已经使用tf.Keras.layers建立了网络,可以说这是最简单,最快的方法。方便的功能model_to_estimator()
modelTF = tf.keras.estimator.model_to_estimator(
keras_model=model,
custom_objects=None,
config=run_config,
model_dir=checkPointDirectory
)
将Keras模型转换为估算器,这使我们可以很好地利用Dataset API,并在训练过程中和训练完成时自动将检查点保存到checkPointDirectory。估算器API提供了一些不可估量的功能,例如通过以下方式在多个GPU上自动分配工作负载:
distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)
现在,对于大型模型和大量数据,使用某种形式的保存模型在训练后执行预测通常很有用。似乎从Tensorflow 1.10开始(请参阅https://github.com/tensorflow/tensorflow/issues/19295),tf.keras.model对象从Tensorflow检查点支持load_weights()。在Tensorflow文档中简要提到了这一点,但在Keras文档中没有提及,我找不到任何显示此示例的人。在一些新的.py中再次定义模型层之后,我尝试了
checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)
但这给出了NotImplementedError:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
2018-10-01 14:24:49.912087:
Traceback (most recent call last):
File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
model.load_weights(filepath=checkPointPath, by_name=False)
File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
checkpointable_utils.streaming_restore(status=status, session=session)
File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
"Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.
我想按照警告的建议进行操作,而是使用“基于对象的保护程序”,但是我还没有找到通过传递给estimator.train()的RunConfig来执行此操作的方法。
那么,是否有更好的方法将保存的权重返回到估算器中以用于预测? github线程似乎暗示已经实现了(尽管基于错误,可能与我上面尝试的方式不同)。有没有人在TF检查点上成功使用过load_weights()?我还没有找到有关如何完成此操作的任何教程/示例,因此不胜感激。
最佳答案
我不确定,但是也许您可以将keras_model.ckpt.index
更改为keras_model.ckpt
进行测试。
关于python - 如何从Tensorflow检查点将权重加载到Keras模型,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52597523/