YOLOv4 Deployment: Converting Keras Weights to TFLite (h5 weights -> pb -> tflite)

I. Environment

1. Training environment

  • tensorflow-gpu==1.4.0
  • Keras==2.1.5
  • python==3.6.2

2. Environment for the h5 -> pb -> tflite conversion

  • tensorflow==2.2.0
  • python==3.6.2
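Training and conversion use two different environments, which makes it easy to run the conversion in the wrong one by accident. The minimal sketch below checks the conversion environment before you start. It assumes the packages were installed with pip under the distribution names shown; if you installed `tensorflow-gpu` instead of `tensorflow`, change the key. The `REQUIRED` table and the `find_mismatches` helper are illustrative and are not part of the original project.

```python
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:
    # Python 3.6/3.7 need the importlib_metadata backport (pip install importlib_metadata)
    from importlib_metadata import PackageNotFoundError, version

# Hypothetical table of the versions expected for the h5 -> pb -> tflite step
REQUIRED = {"tensorflow": "2.2.0"}


def find_mismatches(required):
    """Return {package: (expected, found)} for every package whose installed
    version differs from the expected one; found is None if it is not installed."""
    mismatches = {}
    for name, expected in required.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            found = None
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(REQUIRED)
    for name, (expected, found) in problems.items():
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
    if not problems:
        print("Conversion environment OK")
```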

II. Conversion process

1. Load the h5 weights and export them as pb (a SavedModel)

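  • Change the Keras imports in the YOLOv4 network definition. tensorflow==2.2.0 ships with its own Keras (tf.keras), so the standalone keras package is not needed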

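    # Replace every import that starts with "from keras", e.g.
    from keras.layers import Conv2D
    from keras import backend as K
    # with the tf.keras equivalent:
    from tensorflow.keras.layers import Conv2D
    from tensorflow.keras import backend as K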
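When the network code is spread over several files, you can script the import change. The minimal sketch below rewrites only `from keras ...` imports and leaves lookalike packages such as `keras_applications` untouched. It does not handle plain `import keras` lines, which would still need manual edits. The `rewrite_keras_imports` and `rewrite_tree` helpers are hypothetical.

```python
import pathlib
import re

# Matches "from keras import ..." and "from keras.<submodule> import ..."
# at the start of a line, but not "from keras_applications ..." and similar.
_FROM_KERAS = re.compile(r"^([ \t]*from[ \t]+)keras(?=[.\s])", re.MULTILINE)


def rewrite_keras_imports(source):
    """Return source with every 'from keras...' import pointed at tensorflow.keras."""
    return _FROM_KERAS.sub(r"\1tensorflow.keras", source)


def rewrite_tree(root):
    """Rewrite all .py files under root in place and return the paths that changed."""
    changed = []
    for path in pathlib.Path(root).rglob("*.py"):
        text = path.read_text(encoding="utf-8")
        new_text = rewrite_keras_imports(text)
        if new_text != text:
            path.write_text(new_text, encoding="utf-8")
            changed.append(path)
    return changed


if __name__ == "__main__":
    before = "from keras.layers import Conv2D\nfrom keras import backend as K\n"
    print(rewrite_keras_imports(before))
```

Run `rewrite_tree` on a copy of the project first and diff the result, because it edits files in place.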
    
  • Build the network, load the weights, then export a SavedModel

    import os

    import tensorflow as tf
    from tensorflow.keras.layers import Lambda
    from tensorflow.keras.models import Model

    # yolo_body, DecodeBox, get_classes and get_anchors come from the
    # YOLOv4 project's own source files (network definition and utilities)

    class YOLO(object):
        _defaults = {
            ...
        }

        @classmethod
        def get_defaults(cls, n):
            if n in cls._defaults:
                return cls._defaults[n]
            else:
                return "Unrecognized attribute name '" + n + "'"

        #---------------------------------------------------#
        #   Initialize YOLO
        #---------------------------------------------------#
        def __init__(self, **kwargs):
            self.__dict__.update(self._defaults)
            for name, value in kwargs.items():
                setattr(self, name, value)
                self._defaults[name] = value
            self.class_names, self.num_classes = get_classes(self.classes_path)
            self.anchors, self.num_anchors     = get_anchors(self.anchors_path)
            self.generate()

        #---------------------------------------------------#
        #   Load the model and export it as a SavedModel
        #---------------------------------------------------#
        def generate(self):
            model_path = os.path.expanduser(self.model_path)
            assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
            # Build the network for a 640x640x3 input, then load the trained weights
            self.model = yolo_body([640, 640, 3], ...)
            self.model.load_weights(model_path)
            # Wrap the project's DecodeBox in a Lambda layer so that box
            # decoding becomes part of the exported graph
            outputs = Lambda(
                DecodeBox,
                output_shape = (1,),
                name = 'yolo_eval',
                arguments = {...}
            )(self.model.output)  # self.model.output is the list of three feature layers
            self.yolo_model = Model(self.model.input, outputs)
            # Writes yolo_tflite/yolov4/saved_model.pb plus the variables/ directory
            tf.saved_model.save(self.yolo_model, "yolo_tflite/yolov4")
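`self.model.output` holds three raw feature layers, and the `DecodeBox` Lambda turns them into boxes inside the exported graph. The sketch below shows what that means for a 640x640 input. It assumes the standard YOLOv4 layout (strides 32, 16 and 8, with three anchors per scale) and the plain YOLOv3-style decode of a single cell. The values `num_classes=80` and the anchor size are illustrative. The project's actual `DecodeBox` may also apply letterbox correction, score filtering and NMS.

```python
import math


def yolo_output_shapes(input_size=640, num_classes=80, anchors_per_scale=3, strides=(32, 16, 8)):
    """Shape (grid_h, grid_w, channels) of each raw YOLO feature layer."""
    channels = anchors_per_scale * (5 + num_classes)  # 4 box terms + objectness + class scores
    return [(input_size // s, input_size // s, channels) for s in strides]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def decode_cell(tx, ty, tw, th, cx, cy, grid_size, anchor_w, anchor_h, input_size=640):
    """Decode one raw prediction into a normalized (x_center, y_center, w, h) box.

    cx and cy are the cell's column and row indices; anchor sizes are in input pixels.
    """
    x = (sigmoid(tx) + cx) / grid_size
    y = (sigmoid(ty) + cy) / grid_size
    w = anchor_w * math.exp(tw) / input_size
    h = anchor_h * math.exp(th) / input_size
    return x, y, w, h


if __name__ == "__main__":
    shapes = yolo_output_shapes()
    print(shapes)  # [(20, 20, 255), (40, 40, 255), (80, 80, 255)]
    total = sum(h * w for h, w, _ in shapes) * 3  # 3 anchors per cell
    print(total)  # 25200 candidate boxes before filtering
    print(decode_cell(0.0, 0.0, 0.0, 0.0, cx=0, cy=0, grid_size=20, anchor_w=142, anchor_h=110))
```

Folding this decoding into the graph before export means the `.tflite` file produces boxes directly, so the device does not need to reimplement the decode step.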
    

2. Convert the pb (SavedModel) to tflite

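    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model('yolo_tflite/yolov4')
    # Use built-in TFLite ops where possible and fall back to TensorFlow ops
    # (Select TF ops) for anything TFLite does not support natively
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    # Write the FP32 model and make sure the file is closed
    with open('yolo_tflite/yolo_fp32.tflite', 'wb') as f:
        f.write(tflite_model)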
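Before you move the file to a device, a quick sanity check can confirm that a real TFLite flatbuffer was written. The minimal stdlib-only sketch below relies on the fact that TFLite model files carry the FlatBuffers file identifier `TFL3` at byte offset 4. The `looks_like_tflite` helper is hypothetical.

```python
import os


def looks_like_tflite(path):
    """Return True if the file at path starts like a TFLite flatbuffer.

    FlatBuffers store a 4-byte root offset followed by a 4-byte file
    identifier; the TFLite schema uses the identifier b"TFL3".
    """
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == b"TFL3"


if __name__ == "__main__":
    model_path = "yolo_tflite/yolo_fp32.tflite"
    if os.path.exists(model_path):
        size_mb = os.path.getsize(model_path) / (1024 * 1024)
        print(f"{model_path}: {size_mb:.1f} MB, valid header: {looks_like_tflite(model_path)}")
    else:
        print(f"{model_path} not found; run the conversion first")
```

This check only verifies the header. It does not verify that every op in the model is supported by the interpreter you deploy with.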
