mobilenet v1.ckpt转.tflite格式

slim中的mobilenet v1训练模型生成的checkpoints,无法直接转成tflite格式,原因是tflite的BN层需要自定义操作,无法使用现成的算子。解决方法有两个:

1. 使用量化模型训练,即在使用slim训练的时候参数   --quantize=True,量化模型会把BN层转化为折叠BN层,具体可以查看源码。

2. 我使用的是第二种方法:将训练模型转化为预测模型,然后再转为tflite,其中的BN层会被去掉,代码如下:

import tensorflow as tf
import os
import slim.nets.mobilenet_v1 as mobilenet_v1
import tensorflow.contrib.slim as slim
from tensorflow.python.tools import freeze_graph


def export_eval_pbtxt(MODEL_SAVE_PATH):
    """Export an inference graph (eval .pbtxt) and freeze it with checkpoint weights.

    Builds a MobileNet-V1 inference graph (``is_training=False`` so the
    batch-norm layers are folded away, which makes the graph convertible
    to tflite), writes the graph definition as text, then freezes it
    together with the latest checkpoint found under ``MODEL_SAVE_PATH``.

    Args:
        MODEL_SAVE_PATH: directory containing the training checkpoints.
            Outputs are written to ``<MODEL_SAVE_PATH>/pb_model/``
            (``mobilenet_v1_eval.pbtxt`` and ``frozen_model.pb``);
            that subdirectory must already exist.
    """
    with tf.Graph().as_default() as g:
        images = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='input')
        # is_training=False removes the BN (batch-norm) training behavior.
        with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False, regularize_depthwise=True)):
            _, _ = mobilenet_v1.mobilenet_v1(inputs=images, is_training=False, depth_multiplier=1.0, num_classes=7)

        saver = tf.train.Saver(max_to_keep=5)
        graph_file = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'mobilenet_v1_eval.pbtxt')
        frozen_model = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'frozen_model.pb')

        checkpoint = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        # Bail out early: freeze_graph below requires a valid checkpoint path,
        # so continuing without one would crash on checkpoint being None.
        if not (checkpoint and checkpoint.model_checkpoint_path):
            print("Could not find old network weights")
            return

        with tf.Session() as sess:
            try:
                saver.restore(sess, checkpoint.model_checkpoint_path)
                print("Successfully loaded:", checkpoint.model_checkpoint_path)
            except tf.errors.OpError:
                # Narrow catch: only restore failures, not arbitrary bugs.
                print("Error on loading old network weights")
                return

            print('Learning Started!')
            with open(graph_file, 'w') as f:
                f.write(str(g.as_graph_def()))
            # Freeze: merge graph structure with checkpoint variables and
            # keep only the subgraph feeding the softmax output node.
            freeze_graph.freeze_graph(graph_file,
                                      '',
                                      False,
                                      checkpoint.model_checkpoint_path,
                                      "MobilenetV1/Predictions/Softmax",
                                      'save/restore_all',
                                      'save/Const:0',
                                      frozen_model,
                                      True,
                                      "")


def pb_to_tflite(input_name, output_name, model_save_path=None):
    """Convert the frozen graph produced by export_eval_pbtxt to .tflite.

    Args:
        input_name: name of the graph's input tensor (e.g. ``'input'``).
        output_name: name of the output tensor to keep
            (e.g. ``'MobilenetV1/Predictions/Softmax'``).
        model_save_path: model directory; defaults to the module-level
            ``MODEL_SAVE_PATH`` for backward compatibility with existing
            two-argument callers. Passing it explicitly avoids depending
            on a global that only exists when run as a script.

    Writes ``<model_save_path>/tflite_model/converted_model.tflite``;
    that subdirectory must already exist.
    """
    if model_save_path is None:
        model_save_path = MODEL_SAVE_PATH
    graph_def_file = os.path.join(model_save_path, 'pb_model', 'frozen_model.pb')

    converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, [input_name], [output_name])
    tflite_model = converter.convert()

    tflite_file = os.path.join(model_save_path, 'tflite_model', 'converted_model.tflite')
    # Context manager so the file handle is closed even if the write fails.
    with open(tflite_file, "wb") as f:
        f.write(tflite_model)


if __name__ == '__main__':
    input_name = "input"
    output_name = "MobilenetV1/Predictions/Softmax"
    MODEL_SAVE_PATH = 'model_v1_real'

    # Create the output directories up front instead of requiring the user
    # to create pb_model/ and tflite_model/ by hand before running.
    os.makedirs(os.path.join(MODEL_SAVE_PATH, 'pb_model'), exist_ok=True)
    os.makedirs(os.path.join(MODEL_SAVE_PATH, 'tflite_model'), exist_ok=True)

    export_eval_pbtxt(MODEL_SAVE_PATH)
    pb_to_tflite(input_name, output_name)

 

你可能感兴趣的:(python,tensorflow,tensorflow,mobilenet,tflite,量化模型,tensorflow,lite)