中加载旧的检查点

中加载旧的检查点

本文介绍了在 tensorflow 中加载旧的检查点的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我使用 Tensorflow r0.12 训练了一些模型并保存了它.后来我更新到了r1.0.1.一些模型加载没有任何问题,但如果模型中有 RNN 单元,加载失败,Key layer-5/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_rnn_cell/biases not found in checkpoint.此外,如果我检查 model.index 文件,我会在那里看到类似的条目,例如:5/BiRNN/BW/MultiRNNCell/Cell0/BasicRNNCell/Linear/Bias.

I trained some models using Tensorflow r0.12 and saved it. Later I updated to r1.0.1. Some models are loading without any problems, yet if the model has RNN cells in it, loading fails with Key layer-5/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_rnn_cell/biases not found in checkpoint.Also if I check model.index file I see similar entries there, for example: 5/BiRNN/BW/MultiRNNCell/Cell0/BasicRNNCell/Linear/Bias.

带有 RNN 单元的包现在在 tf.contrib.rnn(在 0.12 中是 tf.nn.rnn_cell),所以我认为一些命名已经改变.

Package with RNN cells is now in tf.contrib.rnn (it was tf.nn.rnn_cell in 0.12), so I think some naming has been changed.

问题是:有没有办法加载我的模型,重新映射其张量并保存,以便张量名称与 r1.0 兼容?

The question is:Is there a way to load my model, re-map its tensors and save so that tensor names would be compatible with r1.0?

附言如果有帮助,我还有 model.meta 文件.

P.S. I also have model.meta file if that helps.

谢谢!

推荐答案

如果有人遇到同样的问题,这里是我使用的解决方案.它是 tensorflow.python.toolsinspect_checkpoint.py 中张量打印函数的修改版本.

If someone gets the same problem, here is the solution I used. It is a modified version of tensor printing function from inspect_checkpoint.py in tensorflow.python.tools.


def resave_tensors(file_name, rename_map, dry_run=False):
    """
    Updates checkpoint by renaming tensors in it.
    :param file_name: Filename with checkpoint.
    :param rename_map: Map from old names to new ones
    :param dry_run: If True, just print new tensors.
    """
    renames_count = 0
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        tensor_val = reader.get_tensor(key)
        print('shape: {}'.format(tensor_val.shape))
        if key in rename_map:
            renames_count += 1
            key = rename_map[key]
        tf.Variable(tensor_val, dtype=tensor_val.dtype, name=key)
    saver = tf.train.Saver()
    if not dry_run:
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            saver.save(session, file_name)
    print('Renamed vars: {}'.format(renames_count))

这篇关于在 tensorflow 中加载旧的检查点的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-05 13:29