本文介绍了使用 tf.estimator.Estimator 加载检查点和微调的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我们正在尝试将基于旧的训练代码转换为更符合 tf.estimator.Estimator 的代码.在初始代码中,我们为目标数据集微调了原始模型.使用 variables_to_restore 和 init_fn 与 MonitoredTrainingSession 的组合,在训练开始之前仅从检查点加载一些层.如何使用 tf.estimator.Estimator 方法实现这种权重加载?
We're trying to translate old training code based into a more tf.estimator.Estimator compliant code. In the initial code we fine tune an original model for a target dataset. Only some layers are loaded from the checkpoint before the training takes place using a combination of variables_to_restore and init_fn with the MonitoredTrainingSession.How can one achieve this kind of weight loading with the tf.estimator.Estimator approach ?
推荐答案
import tensorflow as tf
def model_fn():
# your model defintion here
# ...
# specify your saved checkpoint path
checkpoint_path = "model.ckpt"
ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=checkpoint_path)
est = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=ws)
这篇关于使用 tf.estimator.Estimator 加载检查点和微调的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!