在我当前的项目中,我训练一个模型并每 100 个迭代步骤保存一次检查点。检查点文件都保存在同一目录中(model.ckpt-100、model.ckpt-200、model.ckpt-300 等)。之后,我想根据所有保存的检查点的验证数据来评估模型,而不仅仅是最新的检查点。
目前,我用于恢复检查点文件的一段代码如下所示:
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
ckpt_list = saver.last_checkpoints
print(ckpt_list)
if ckpt and ckpt.model_checkpoint_path:
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
# extract global_step from it.
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
print('Succesfully loaded model from %s at step=%s.' %
(ckpt.model_checkpoint_path, global_step))
else:
print('No checkpoint file found')
return
但是,这只会恢复最近保存的检查点文件。那么如何在所有保存的检查点文件上编写一个循环呢?我尝试使用 saver.last_checkpoints 获取检查点文件列表,但是,返回的列表为空。
任何帮助将不胜感激,提前致谢!
最佳答案
最快的解决方案:tensor2tensor
有一个模块 utils
和一个脚本 avg_checkpoints.py
将平均权重保存在一个新的检查点。假设您有一个想要平均的检查点列表。您有 2 个使用选项:
TRAIN_DIR=path_to_your_model_folder
FNC_PATH=path_to_tensor2tensor+'/utils/avg.checkpoints.py'
CKPTS=model.ckpt-10000,model.ckpt-20000,model.ckpt-100000
python3 $FNC_PATH --prefix=$TRAIN_DIR --checkpoints=$CKPTS \
--output_path="${TRAIN_DIR}averaged.ckpt"
os.system
):import os
os.system(
"python3 "+FNC_DIR+" --prefix="+TRAIN_DIR+" --checkpoints="+CKPTS+
" --output_path="+TRAIN_DIR+"averaged.ckpt"
)
作为指定检查点列表并使用
--checkpoints
参数的替代方法,您可以仅使用 --num_checkpoints=10
对最后 10 个检查点进行平均。如果您不想依赖
tensor2tensor
:这是一个不依赖于
tensor2tensor
的代码片段,但仍然可以平均 可变数量的检查点 (与 ted 的答案相反)。假设 steps
是应该合并的检查点列表(例如 [10000, 20000, 30000, 40000]
)。然后:
# Restore all sessions and save the weight matrices
values = []
for step in steps:
tf.reset_default_graph()
path = model_path+'/model.ckpt-'+str(step)
with tf.Session() as sess:
saver = tf.train.import_meta_graph(path+'.meta')
saver.restore(sess, path)
values.append(sess.run(tf.all_variables()))
# Average weights
variables = tf.all_variables()
all_assign = []
for ind, var in enumerate(variables):
weights = np.concatenate(
[np.expand_dims(w[ind],axis=0) for w in values],
axis=0
)
all_assign.append(tf.assign(var, np.mean(weights, axis=0))
然后你可以继续,但是你喜欢,例如保存平均检查点:# Now save the new values into a separate checkpoint
with tf.Session() as sess_test:
sess_test.run(all_assign)
saver = tf.train.Saver()
saver.save(sess_test, model_path+'/average_'+str(num_checkpoints))
关于python - tensorflow:在多个检查点上运行模型评估,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/42532061/