本文介绍了使用 tf.estimator.Estimator 加载检查点和微调的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我们正在尝试将基于旧的训练代码转换为更符合 tf.estimator.Estimator 的代码.在初始代码中,我们为目标数据集微调了原始模型.使用 variables_to_restoreinit_fnMonitoredTrainingSession 的组合,在训练开始之前仅从检查点加载一些层.如何使用 tf.estimator.Estimator 方法实现这种权重加载?

We're trying to translate old training code based into a more tf.estimator.Estimator compliant code. In the initial code we fine tune an original model for a target dataset. Only some layers are loaded from the checkpoint before the training takes place using a combination of variables_to_restore and init_fn with the MonitoredTrainingSession.How can one achieve this kind of weight loading with the tf.estimator.Estimator approach ?

推荐答案

import tensorflow as tf    

def model_fn():
  # your model defintion here
  # ...

# specify your saved checkpoint path
checkpoint_path = "model.ckpt"

ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=checkpoint_path)
est = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=ws)

这篇关于使用 tf.estimator.Estimator 加载检查点和微调的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-12 02:47