最近工作里需要用到tensorflow的pretrained-model去做retrain. 记录一下.

为什么可以用pretrained-model去做retrain

这个就要引出CNN的本质了.CNN的本质就是求出合适的卷积核,提取出合理的底层特征.进而为不同的特征赋以权重.从而表达图像.

通俗点讲,比如有一张猫的图片,你怎么判断是猫不是狗?你可能会看到图里有猫的头,猫的爪子,猫的尾巴. 头/爪子/尾巴 就是CNN中比较靠前的层所提取出来的特征,我们称之为高级特征,这时候的特征我们人类还是能理解的. 继续对这些头/爪子/尾巴继续做特征提取,...,最终得到的特征已经非常细节非常抽象了,可能是一个点,一条线等等. 最终我们的image=这些低级特征乘以不同权重,求和.

假设现在你有一个基于公开数据集的trained-model.这个数据集里没有你想识别的图片,比如红绿灯吧. 但是,没关系!!,虽然你之前的模型不认识红绿灯,但是它也抽象出来了很多底层的抽象的细节特征啊,点啊,线啊之类的. 我们依然可以使用这些特征去表示红绿灯图片,只是每个特征的权重要改变而已! 这就是所谓的增强学习.

tensorflow里存储"很多底层的抽象的细节特征啊,点啊,线啊之类的"文件,称之为module.更多详细的见https://www.tensorflow.org/hub/tutorials/image_retraining

环境准备

  • conda activate venv_python3.6
  • pip install "tensorflow>=1.7.0"
  • pip install tensorflow-hub

数据准备

示例代码下载

重训练

  • python retrain.py --image_dir ~/flower_photos

训练相关的文件模型等存储于/tmp

  • /tmp/bottleneck 可以理解为每一个图片的feature map 存储的是新的class的image的抽象特征
  • /tmp/output_graph.pb 新的模型
  • /tmp/output_labels.txt 新识别出的label

bottleneck可以理解为image feature vector.可以理解为各种抽象的特征,点啊直线啊折线啊,利用这些特征,模型可以去做分类.

  • training accuracy 训练集精度
  • validation accuracy 验证集精度
  • Cross entropy 交叉熵

整体而言,cross entropy应该是不断减小的,中间可能会有小的波动

train.py

python retrain.py \
--image_dir ~/flower_photos \
--tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2
  1. 会从url + '?tf-hub-format=compressed'下载module包.默认会下载到/tmp/tfhub_modules
tar -xvf ../module.tar ./
./
./saved_model.pb
./variables/
./variables/variables.index
./variables/variables.data-00000-of-00001
./assets/
./tfhub_module.pb

这里面就包含了抽象的底层特征.

ssd module下载

https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1

数据集结构

tensorflow用pretrained-model做retrain-LMLPHP

每个目录下是相应类别的jpg文件

数据集的搜集应当注意的几点问题

如何使用本地model做retrain

这一步还没成功,因为我的需求比较特殊,我需要在jetson nano上跑模型,而tensorrt目前还是有Bug的,不是什么model都能推理,有的model里的算子不支持.而从tensorflow的官网download的ssd model的module,做retrain后得到的model无法在jetson nano上推理,

目前我需要ssd_inception_v2_coco_2017_11_17这个model对应的module,很不幸,并没有,只能自己写代码去做转换,使用了官方的create_module_spec_from_saved_model api还是有问题

与此问题相关的link

https://github.com/tensorflow/hub/issues/37

https://github.com/tensorflow/hub/blob/52d5066e925d345fbd54ddf98b7cadf027b69d99/examples/image_retraining/retrain.py 对应分支

https://www.tensorflow.org/hub/creating

python retrain.py

--image_dir ~/flower_photos

--tfhub_module ./ssd_inception_v2_coco_2017_11_17

tensorflow文件含义

  • .pb文件 存储了完整的模型的结构信息,变量信息等.
  • checkpoint文件 记录模型路径信息
cat checkpoint
model_checkpoint_path: "/tmp/_retrain_checkpoint"
all_model_checkpoint_paths: "/tmp/_retrain_checkpoint"
  • .meta文件存储了运算图的结构
  • .index文件存储了tensor结构的信息,ensorname<-->BundleEntryProto
  • .data文件存储所有变量的值
05-04 07:40