最近工作里需要用到tensorflow的pretrained-model去做retrain. 记录一下.
为什么可以用pretrained-model去做retrain
这个就要引出CNN的本质了.CNN的本质就是求出合适的卷积核,提取出合理的底层特征.进而为不同的特征赋以权重.从而表达图像.
通俗点讲,比如有一张猫的图片,你怎么判断是猫不是狗?你可能会看到图里有猫的头,猫的爪子,猫的尾巴. 头/爪子/尾巴 就是CNN中比较靠前的层所提取出来的特征,我们称之为高级特征,这时候的特征我们人类还是能理解的. 继续对这些头/爪子/尾巴继续做特征提取,...,最终得到的特征已经非常细节非常抽象了,可能是一个点,一条线等等. 最终我们的image=这些低级特征乘以不同权重,求和.
假设现在你有一个基于公开数据集的trained-model.这个数据集里没有你想识别的图片,比如红绿灯吧. 但是,没关系!!,虽然你之前的模型不认识红绿灯,但是它也抽象出来了很多底层的抽象的细节特征啊,点啊,线啊之类的. 我们依然可以使用这些特征去表示红绿灯图片,只是每个特征的权重要改变而已! 这就是所谓的增强学习.
tensorflow里存储"很多底层的抽象的细节特征啊,点啊,线啊之类的"文件,称之为module.更多详细的见https://www.tensorflow.org/hub/tutorials/image_retraining
环境准备
- conda activate venv_python3.6
- pip install "tensorflow>=1.7.0"
- pip install tensorflow-hub
数据准备
- cd ~
- curl -LO http://download.tensorflow.org/example_images/flower_photos.tgz
- tar xzf flower_photos.tgz
示例代码下载
- mkdir ~/example_code
- cd ~/example_code
- curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py
重训练
- python retrain.py --image_dir ~/flower_photos
训练相关的文件模型等存储于/tmp
- /tmp/bottleneck 可以理解为每一个图片的feature map 存储的是新的class的image的抽象特征
- /tmp/output_graph.pb 新的模型
- /tmp/output_labels.txt 新识别出的label
bottleneck可以理解为image feature vector.可以理解为各种抽象的特征,点啊直线啊折线啊,利用这些特征,模型可以去做分类.
- training accuracy 训练集精度
- validation accuracy 验证集精度
- Cross entropy 交叉熵
整体而言,cross entropy应该是不断减小的,中间可能会有小的波动
train.py
python retrain.py \
--image_dir ~/flower_photos \
--tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2
- 会从url + '?tf-hub-format=compressed'下载module包.默认会下载到/tmp/tfhub_modules
tar -xvf ../module.tar ./
./
./saved_model.pb
./variables/
./variables/variables.index
./variables/variables.data-00000-of-00001
./assets/
./tfhub_module.pb
这里面就包含了抽象的底层特征.
ssd module下载
https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1
数据集结构
每个目录下是相应类别的jpg文件
数据集的搜集应当注意的几点问题
如何使用本地model做retrain
这一步还没成功,因为我的需求比较特殊,我需要在jetson nano上跑模型,而tensorrt目前还是有Bug的,不是什么model都能推理,有的model里的算子不支持.而从tensorflow的官网download的ssd model的module,做retrain后得到的model无法在jetson nano上推理,
目前我需要ssd_inception_v2_coco_2017_11_17这个model对应的module,很不幸,并没有,只能自己写代码去做转换,使用了官方的create_module_spec_from_saved_model api还是有问题
与此问题相关的link
https://github.com/tensorflow/hub/issues/37
https://github.com/tensorflow/hub/blob/52d5066e925d345fbd54ddf98b7cadf027b69d99/examples/image_retraining/retrain.py 对应分支
https://www.tensorflow.org/hub/creating
python retrain.py
--image_dir ~/flower_photos
--tfhub_module ./ssd_inception_v2_coco_2017_11_17
tensorflow文件含义
- .pb文件 存储了完整的模型的结构信息,变量信息等.
- checkpoint文件 记录模型路径信息
cat checkpoint
model_checkpoint_path: "/tmp/_retrain_checkpoint"
all_model_checkpoint_paths: "/tmp/_retrain_checkpoint"
- .meta文件存储了运算图的结构
- .index文件存储了tensor结构的信息,ensorname<-->BundleEntryProto
- .data文件存储所有变量的值