问题描述
我一直在试验新的 TensorFlow 中可用的 8 位量化功能.我可以在没有任何问题的情况下运行博客文章中给出的示例(googlenet 的量化),它对我来说很好用!!!
I have been experimenting with the new 8-bit quantization feature available in TensorFlow. I could run the example given in the blog post (quantization of googlenet) without any issue and it works fine for me !!!
现在,我想将相同的内容应用于更简单的网络.所以我使用了 CIFAR-10 的预训练网络(在 Caffe 上训练),提取其参数,在 tensorflow 中创建相应的图,用这个预训练的权重初始化权重,最后将其保存为 GraphDef 对象.请参阅此 IPython Notebook 了解完整过程.
Now, I would like to apply the same for a simpler network. So I used a pre-trained network for CIFAR-10 (which is trained on Caffe), extracted its parameters, created corresponding graph in tensorflow, initialized the weights with this pre-trained weights and finally saved it as a GraphDef object. See this IPython Notebook for full procedure.
现在我使用 tensorflow 脚本应用了 8 位量化,如 Pete Warden 的博客中所述:
Now I applied the 8-bit quantization with the tensorflow script as mentioned in the Pete Warden's blog:
bazel-bin/tensorflow/contrib/quantization/tools/quantize_graph --input=cifar.pb --output=qcifar.pb --mode=eightbit --bitdepth=8 --output_node_names="ArgMax"
现在我想在这个量化网络上运行分类.所以我将新的 qcifar.pb
加载到 tensorflow 会话并传递图像(与我将其传递给原始版本的方式相同).完整代码可在此 IPython Notebook 中找到.
Now I wanted to run the classification on this quantized network. So I loaded the new qcifar.pb
to a tensorflow session and passed the image (the same way I passed it to original version). Full code can be found in this IPython Notebook.
但正如你在最后看到的,我收到以下错误:
But as you can see at the end, I am getting following error:
NotFoundError:操作类型未注册QuantizeV2"
有人可以建议我在这里缺少什么吗?
Can anybody suggest what am I missing here?
推荐答案
由于量化操作和内核在 contrib 中,您需要在 Python 脚本中显式加载它们.在 quantize_graph 中有一个例子.py 脚本本身:
Because the quantized ops and kernels are in contrib, you'll need to explicitly load them in your python script. There's an example of that in the quantize_graph.py script itself:
from tensorflow.contrib.quantization import load_quantized_ops_so从 tensorflow.contrib.quantization.kernels 导入 load_quantized_kernels_so
这是我们应该更新文档以提及的内容!
This is something that we should update the documentation to mention!
这篇关于Tensorflow 中的 8 位量化错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!