Exporting a Keras model to a .pb file and optimizing it for inference gives random guesses on Android

Problem description

I am developing an Android application for age and gender recognition. I found a useful model on GitHub: a Keras model (TensorFlow backend) based on a first-place winning paper. The authors provide Python modules to build and train the network, a pretrained weights file to download and work with, and a working webcam demo.

I want to convert their model, as used in the demo, together with the provided weights into a .pb file so that it is executable on Android as well.

I used the following code for the conversion, with minor model-dependent modifications:

from keras.models import Sequential
from keras.models import model_from_json
from keras import backend as K
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
import os

# Load existing model.
with open("model.json",'r') as f:
    modelJSON = f.read()

model = model_from_json(modelJSON)
model.load_weights("weights.18-4.06.hdf5")
print(model.summary())

# All new operations will be in test mode from now on.
# Note: set_learning_phase(0) only affects ops created after this call,
# so the ops already built by model_from_json above are not changed.
K.set_learning_phase(0)

# Serialize the model and get its weights, for quick re-building.
config = model.get_config()
weights = model.get_weights()

# Re-build a model where the learning phase is now hard-coded to 0.
#new_model = model.from_config(config)
#new_model.set_weights(weights)

temp_dir = "graph"
checkpoint_prefix = os.path.join(temp_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
output_graph_name = "output_graph.pb"

# Temporary save graph to disk without weights included.
saver = tf.train.Saver()
checkpoint_path = saver.save(K.get_session(), checkpoint_prefix, global_step=0, latest_filename=checkpoint_state_name)
tf.train.write_graph(K.get_session().graph, temp_dir, input_graph_name)

input_graph_path = os.path.join(temp_dir, input_graph_name)
input_saver_def_path = ""
input_binary = False
output_node_names = "dense_1/Softmax,dense_2/Softmax" # model dependent
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(temp_dir, output_graph_name)
clear_devices = False

# Embed weights inside the graph and save to disk.
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, "")
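
In case the output node names differ in your build (Keras numbers layers per session), a quick way to double-check the names passed as output_node_names is to list the Softmax nodes in the current session graph. This is a verification sketch of my own, not part of the original code; run it in the same script as the conversion above:

# Print candidate output node names before freezing.
softmax_nodes = [n.name for n in K.get_session().graph.as_graph_def().node
                 if "Softmax" in n.name]
print(softmax_nodes)  # expect something like ['dense_1/Softmax', 'dense_2/Softmax']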

I produced the model.json file from the demo directly. The main function of the demo.py file, with the model.json export added, is:

def main():
    args = get_args()
    depth = args.depth
    k = args.width
    weight_file = args.weight_file

    if not weight_file:
        weight_file = get_file("weights.18-4.06.hdf5", pretrained_model, cache_subdir="pretrained_models",
                               file_hash=modhash, cache_dir=os.path.dirname(os.path.abspath(__file__)))

    # for face detection
    detector = dlib.get_frontal_face_detector()

    # load model and weights
    img_size = 64
    model = WideResNet(img_size, depth=depth, k=k)()
    model.load_weights(weight_file)
    print(model.summary())

    # write model to json
    model_json = model.to_json()
    with open("model.json", "w") as json_file:
        json_file.write(model_json)

    for img in yield_images():
        input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_h, img_w, _ = np.shape(input_img)

        # detect faces using dlib detector
        detected = detector(input_img, 1)
        faces = np.empty((len(detected), img_size, img_size, 3))

        if len(detected) > 0:
            for i, d in enumerate(detected):
                x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height()
                xw1 = max(int(x1 - 0.4 * w), 0)
                yw1 = max(int(y1 - 0.4 * h), 0)
                xw2 = min(int(x2 + 0.4 * w), img_w - 1)
                yw2 = min(int(y2 + 0.4 * h), img_h - 1)
                cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                # cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2)
                faces[i, :, :, :] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1, :], (img_size, img_size))

            # predict ages and genders of the detected faces
            results = model.predict(faces)
            predicted_genders = results[0]
            ages = np.arange(0, 101).reshape(101, 1)
            predicted_ages = results[1].dot(ages).flatten()

            # draw results
            for i, d in enumerate(detected):
                label = "{}, {}".format(int(predicted_ages[i]),
                                        "F" if predicted_genders[i][0] > 0.5 else "M")
                draw_label(img, (d.left(), d.top()), label)

        cv2.imshow("result", img)
        key = cv2.waitKey(30)

        if key == 27:
            break


if __name__ == '__main__':
    main()
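
A side note on the age prediction in the loop above: dense_2 outputs a softmax over 101 age bins (ages 0 to 100), and results[1].dot(ages) takes the expected value of that distribution. A minimal numeric illustration with made-up probabilities:

import numpy as np

# Hypothetical dense_2 output for one face: all mass split between 30 and 40.
probs = np.zeros((1, 101))
probs[0, 30] = 0.5
probs[0, 40] = 0.5

ages = np.arange(0, 101).reshape(101, 1)
print(probs.dot(ages).flatten())  # [35.] -- the expected age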

The code compiles successfully and produces multiple checkpoint files alongside a .pb file.

This is the graph summary of the model:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 64, 64, 3)    0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 64, 64, 16)   432         input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64, 64, 16)   64          conv2d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 64, 64, 16)   0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 128)  18432       activation_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 64, 64, 128)  512         conv2d_2[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 128)  0           batch_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 128)  147456      activation_2[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 128)  2048        activation_1[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, 64, 64, 128)  0           conv2d_3[0][0]
                                                                 conv2d_4[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 64, 64, 128)  512         add_1[0][0]
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 64, 64, 128)  0           batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 128)  147456      activation_3[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_5[0][0]
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 128)  147456      activation_4[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, 64, 64, 128)  0           conv2d_6[0][0]
                                                                 add_1[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 128)  512         add_2[0][0]
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 64, 64, 128)  0           batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 256)  294912      activation_5[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 32, 32, 256)  1024        conv2d_7[0][0]
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 256)  0           batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  589824      activation_6[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 32, 32, 256)  32768       activation_5[0][0]
__________________________________________________________________________________________________
add_3 (Add)                     (None, 32, 32, 256)  0           conv2d_8[0][0]
                                                                 conv2d_9[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 32, 32, 256)  1024        add_3[0][0]
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 32, 32, 256)  0           batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 256)  589824      activation_7[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 256)  1024        conv2d_10[0][0]
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 32, 32, 256)  0           batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 256)  589824      activation_8[0][0]
__________________________________________________________________________________________________
add_4 (Add)                     (None, 32, 32, 256)  0           conv2d_11[0][0]
                                                                 add_3[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 32, 32, 256)  1024        add_4[0][0]
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 32, 32, 256)  0           batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 512)  1179648     activation_9[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 16, 16, 512)  2048        conv2d_12[0][0]
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 16, 16, 512)  0           batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 512)  2359296     activation_10[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 512)  131072      activation_9[0][0]
__________________________________________________________________________________________________
add_5 (Add)                     (None, 16, 16, 512)  0           conv2d_13[0][0]
                                                                 conv2d_14[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 16, 16, 512)  2048        add_5[0][0]
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 16, 16, 512)  0           batch_normalization_11[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 16, 16, 512)  2359296     activation_11[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 16, 16, 512)  2048        conv2d_15[0][0]
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 16, 16, 512)  0           batch_normalization_12[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 16, 16, 512)  2359296     activation_12[0][0]
__________________________________________________________________________________________________
add_6 (Add)                     (None, 16, 16, 512)  0           conv2d_16[0][0]
                                                                 add_5[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        add_6[0][0]
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 16, 16, 512)  0           batch_normalization_13[0][0]
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 16, 16, 512)  0           activation_13[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 131072)       0           average_pooling2d_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 2)            262144      flatten_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 101)          13238272    flatten_1[0][0]
==================================================================================================
Total params: 24,463,856
Trainable params: 24,456,656
Non-trainable params: 7,200
__________________________________________________________________________________________________

I took the output model and used the following script to optimize it for inference:

python -m tensorflow.python.tools.optimize_for_inference --input output_graph.pb --output g.pb --input_names=input_1 --output_names=dense_1/Softmax,dense_2/Softmax
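
As a structural sanity check (a sketch of my own, assuming TensorFlow 1.x APIs), the optimized graph can be loaded back to confirm that the expected input and output nodes survived the optimization pass:

import tensorflow as tf

# Parse the optimized graph and print the nodes of interest.
graph_def = tf.GraphDef()
with tf.gfile.GFile("g.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    if "input" in node.name or "Softmax" in node.name:
        print(node.name, node.op)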

While the optimization ran, the terminal printed many warnings like these:

 FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (16,), for node batch_normalization_1/FusedBatchNorm
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_2/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_3/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_4/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_5/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_6/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_7/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_8/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_9/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_10/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_11/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_12/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_13/FusedBatchNorm'

These warnings look bad! (Note that the "Didn't find expected Conv2D input" messages line up exactly with the batch-normalization layers that sit after Add layers in the summary above, so the batch-norm folding pattern does not match there.)

I have tried both files in my Android app. The optimized file does not work at all, while the non-optimized file is executable but produces nonsensical results, e.g. it appears to be guessing randomly.

I know the question is a little long, but it is the summary of a whole working day and I do not want to miss any detail.

I do not know where the problem is: in the output node names, in freezing the graph, in instantiating the model with the weights, or in the optimize-for-inference script.
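
One way to isolate the failing stage (a debugging sketch of my own, not from the original post) is to run the frozen graph on the desktop with the same 64x64 input the Keras model receives, and compare the two outputs. If they match, the conversion is fine and the bug must be in the Android-side feeding:

import numpy as np
import tensorflow as tf

# Load the frozen graph produced by freeze_graph above.
graph_def = tf.GraphDef()
with tf.gfile.GFile("graph/output_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

# Random test input; compare these outputs with model.predict(x) in Keras.
x = np.random.rand(1, 64, 64, 3).astype(np.float32) * 255.0
with tf.Session(graph=graph) as sess:
    genders, age_probs = sess.run(
        ["dense_1/Softmax:0", "dense_2/Softmax:0"],
        feed_dict={"input_1:0": x})
print(genders, age_probs.dot(np.arange(0, 101)))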

Recommended answer

After some research, the problem of random guessing was finally resolved.

The problem was not in converting the model to a .pb file, as I first suspected, but in feeding the image to the model correctly on Android.

I worked on converting the model again. The following points summarize my work.

  • First, I took the model from the demo.py referenced above in the question and saved it to an .h5 file with model.save('./saved_model/model.h5').
  • Second, I took the generated .h5 file and converted it to a .pb file using the code in this repository: https://github.com/amir-abdi/keras_to_tensorflow. The code of this repository proved its reliability: it converts the model to a .pb file and optimizes it for inference in one pass. It is amazing! (A rough sketch of what such a conversion does internally is shown after this list.)
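
For reference, here is a minimal sketch of what such an .h5-to-.pb conversion does internally, assuming Keras with the TensorFlow 1.x backend (the repository's own script is the tested path and should be preferred):

import tensorflow as tf
from keras import backend as K
from keras.models import load_model
from tensorflow.python.framework import graph_util

# Fix the learning phase to test mode BEFORE the model is built, so no
# training-only behavior (dropout, batch-norm updates) ends up in the graph.
K.set_learning_phase(0)
model = load_model("./saved_model/model.h5")

sess = K.get_session()
output_names = [out.op.name for out in model.outputs]  # e.g. dense_1/Softmax
frozen = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), output_names)
with tf.gfile.GFile("./saved_model/model.pb", "wb") as f:
    f.write(frozen.SerializeToString())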

Third, I put the generated .pb file into the Android assets folder in order to use it from my application.

Fourth, I converted the target image to pixel values and used bit-wise shifting to extract the color channels, adapting this code to complete the task. Keep in mind that the getPixels method preserves the channel ordering (packed ARGB ints, so bit-shifting yields R, G, B). The demo feeds the model OpenCV images, which are in BGR order, so the channels need to be reversed, as in the following code. I got this from this answer.

    Bitmap bitmap = Bitmap.createScaledBitmap(faces[0], INPUT_SIZE, INPUT_SIZE, true);

    // Get pixel values as packed ARGB ints.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

    for (int i = 0; i < intValues.length; ++i) {

        final int val = intValues[i];

        // Extracting colors with bit-wise shifting yields R, G, B order:
        // floatValues[i * 3 + 0] = ((val >> 16) & 0xFF);
        // floatValues[i * 3 + 1] = ((val >> 8) & 0xFF);
        // floatValues[i * 3 + 2] = (val & 0xFF);

        // Reverse the channel ordering to B, G, R instead.
        floatValues[i * 3 + 2] = Color.red(val);
        floatValues[i * 3 + 1] = Color.green(val);
        floatValues[i * 3] = Color.blue(val);
    }

  • Finally, I can use the TensorFlow inference methods to feed the image to the model, run inference, and read out the results.
