问题描述
我有 inception_resnet_v2_2016_08_30.ckpt
文件,它是一个预训练的初始模型.我想使用
saver.restore(sess, ckpt_filename)
但为此,我将需要编写训练此模型时使用的变量集.我在哪里可以找到这些(脚本或详细说明)?
首先你要了解内存中的网络架构.您可以从 此处 获取网络架构>
一旦你有了这个程序,使用以下方法来使用模型:
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope高度 = 299宽度 = 299频道 = 3X = tf.placeholder(tf.float32, shape=[None, height, width, channels])使用 slim.arg_scope(inception_resnet_v2_arg_scope()):logits, end_points = inception_resnet_v2(X, num_classes=1001,is_training=False)
这样你就拥有了内存中的所有网络,现在你可以使用 tf.train.saver 使用检查点文件(ckpt)初始化网络:
saver = tf.train.Saver()sess = tf.Session()saver.restore(sess, "/home/pramod/Downloads/inception_resnet_v2_2016_08_30.ckpt")
如果你想做瓶子提取,它很简单,比如你想从最后一层获取特征,那么你只需要声明 predictions = end_points["Logits"]
如果你想为其他中间层获取它,你可以从上面的程序 inception_resnet_v2.py 中获取那些名字
之后你可以调用:output = sess.run(predictions, feed_dict={X:batch_images})
I have inception_resnet_v2_2016_08_30.ckpt
file which is a pre-trained inception model. I want to restore this model using
saver.restore(sess, ckpt_filename)
But for that, I will be required to write the set of variables that were used while training this model. Where can I find those (a script, or detailed description)?
First of you have get the network architecture in memory. You can get the network architecture from here
Once you have this program with you, use the following approach to use the model:
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
height = 299
width = 299
channels = 3
X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_resnet_v2_arg_scope()):
logits, end_points = inception_resnet_v2(X, num_classes=1001,is_training=False)
With this you have all the network in memory, Now you can initialize the network with checkpoint file(ckpt) by using tf.train.saver:
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "/home/pramod/Downloads/inception_resnet_v2_2016_08_30.ckpt")
If you want to do bottle extractions, its simple like lets say you want to get features from last layer, then simply you have to declare predictions = end_points["Logits"]
If you want to get it for other intermediate layer, you can get those names from the above program inception_resnet_v2.py
After that you can call: output = sess.run(predictions, feed_dict={X:batch_images})
这篇关于如何恢复 tensorflow inceptions 检查点文件(ckpt)?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!