利用tensorflow搭建Alexnet,并进行.tflite的导出

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

        TensorFlow Lite是一种用于设备端推断的开源深度学习框架,可帮助开发者在移动设备、嵌入式设备和IoT设备上运行TensorFlow模型,本篇文章将简要介绍利用tensorflow搭建Alexnet网络,并进行.tflite文件导出的过程,为后续在移动设备、嵌入式设备中使用.tflite文件进行后续开发做准备。


一、搭建Alexnet网络

        使用tf中的keras API进行网络的搭建,tf中的keras API与keras存在一定的区别,tf作为一种静态计算流,在训练阶段具有较高的效率,但是对于用户而言,进行网络的搭建、修改等工作时,可使用性较低,kerase最初是谷歌的一位研究人员为了方便自己的研究而开发的一款简化深度学习操作的封装,允许工作人员灵活的构建网络,可用性较强,随后tf开发了针对于kerase的高级API,两者在功能上类似,但是在使用过程中要注意区分,否则有可能引发异常。

from tensorflow.keras import layers, models, Model, Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import os
from PIL import Image
import numpy as np
import json
import matplotlib.pyplot as plt

def AlexNet(im_height=224, im_width=224, class_num=3):  #class_num需要根据实际识别类别设置,宽高固定为224
    # tensorflow中的tensor通道排序是NHWC
    input_image = layers.Input(shape=(im_height, im_width, 3), dtype="float32") 
    x = layers.ZeroPadding2D(((1, 2), (1, 2)))(input_image) 
    x = layers.Conv2D(48, kernel_size=11, strides=4, activation="relu")(x) 
    x = layers.MaxPool2D(pool_size=3, strides=2)(x) 
    x = layers.Conv2D(128, kernel_size=5, padding="same", activation="relu")(x)
    x = layers.MaxPool2D(pool_size=3, strides=2)(x)                         
    x = layers.Conv2D(192, kernel_size=3, padding="same", activation="relu")(x) 
    x = layers.Conv2D(192, kernel_size=3, padding="same", activation="relu")(x) 
    x = layers.Conv2D(128, kernel_size=3, padding="same", activation="relu")(x)  
    x = layers.MaxPool2D(pool_size=3, strides=2)(x) 
    x = layers.Flatten()(x)         
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(2048, activation="relu")(x) 
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(2048, activation="relu")(x)
    x = layers.Dense(class_num)(x)        
    predict = layers.Softmax()(x)
    model = models.Model(inputs=input_image, outputs=predict)
    return model

###################################################
########  train  训练数据集按照以下文档结构设置
#####     data_root 
#####         train
####              label_1 
####                   img1
####                   img2
####                   ...
####              label_2 
####                   img1
####                   img2
####                   ...
####              ...
#####         val
####              label_1 
####                   img1
####                   img2
####                   ...
####              label_2 
####                   img1
####                   img2
####                   ...
####              ...
###########################################
def train(image_path, im_height=224, im_width=224, batch_size=3, epochs=10):
    train_dir = image_path + "train"
    validation_dir = image_path + "val"
    if not os.path.exists("save_weights"):
        os.makedirs("save_weights")#在指定的路径下创建文件夹 用来保存训练模型的权重

    #keras模块提供的图片生成器:可以载入文件夹下的图片生成器并对其进行预处理
    train_image_generator = ImageDataGenerator(rescale=1. / 255,
                                               horizontal_flip=True)
    validation_image_generator = ImageDataGenerator(rescale=1. / 255)
    train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,#训练集目录
                                                               batch_size=batch_size,
                                                               shuffle=True,#是否打乱图片的相对顺序
                                                               target_size=(im_height, im_width),#输入图片尺寸
                                                               class_mode='categorical')#分类的方式
    total_train = train_data_gen.n #获得训练集训练样本的个数

    # get class dict
    class_indices = train_data_gen.class_indices#字典类型,返回每个类别和其索引
    print(class_indices)  #查看字典,确保信息正确性
    # 将key和value进行反转 得到反过来的字典 (目的:在预测的过程中通过索引直接对应到类别中)
    inverse_dict = dict((val, key) for key, val in class_indices.items())
    json_str = json.dumps(inverse_dict, indent=4)#将图像类别及其索引保存下来,以便后续测试过程的使用
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                                  batch_size=batch_size,
                                                                  shuffle=False,
                                                                  target_size=(im_height, im_width),                                                      class_mode='categorical')
    total_val = val_data_gen.n

    model = AlexNet(im_height=im_height, im_width=im_width, class_num=3)#实例化网络
    model.summary()#输出模型的参数信息

    # using keras high level api for training
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                  metrics=["accuracy"])
    #定义回调函数(保存模型的一些规则)的列表
    callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex.h5',#保存模型的位置:当前文件夹下
                                                    save_best_only=True,#是否保存最佳参数 还是保存最后的训练参数
                                                    save_weights_only=False,#是否只保存权重 如果不止权重文件还有模型文件,这样就不需要创建网络直接调用模型文件即可
                                                    monitor='val_loss')]#所监控的参数:验证集的损失 判断是不是最佳,变小的话模型效果就会变好

    # 训练过程的一些信息保存在history中
    history = model.fit(x=train_data_gen,#训练集生成器
                        steps_per_epoch=total_train // batch_size,#每一轮要迭代多少次即一个epoch要迭代多少次 //是除
                        epochs=epochs,#迭代多少轮
                        validation_data=val_data_gen,#给定验证集生成器
                        validation_steps=total_val // batch_size,#验证集的时候没有dropout fit方法自动实现了
                        callbacks=callbacks)#

    # plot loss and accuracy image
    history_dict = history.history#通过这样的方法可以获取到数据字典 保存了训练集的损失和准确率,验证集的损失和准确率
    train_loss = history_dict["loss"]
    train_accuracy = history_dict["acc"]
    val_loss = history_dict["val_loss"]
    val_accuracy = history_dict["val_acc"]

    history = model.fit_generator(generator=train_data_gen,
                                  steps_per_epoch=total_train // batch_size,
                                  epochs=epochs,
                                  validation_data=val_data_gen,
                                  validation_steps=total_val // batch_size,
                                  callbacks=callbacks)

#########################################################
## test
########################################################
def test(img_dir, im_height = 224, im_width = 224):
    # load image
    img = Image.open(img_dir)
    # resize image to 224x224
    img = img.resize((im_width, im_height))

    # scaling pixel value to (0-1)  与训练过程的图像预处理保持一致
    img = np.array(img) / 255.

    # Add the image to a batch where it's the only member. 扩充图片维度,输入到网络中必须是(batch 宽 高 深)
    img = (np.expand_dims(img, 0))

    # read class_indict
    try:
        json_file = open('./class_indices.json', 'r')#读取之前保存好的json文件
        class_indict = json.load(json_file)#对应的类别信息
    except Exception as e:
        print(e)
        exit(-1)

    model = AlexNet(class_num=3)
    model.load_weights("./save_weights/myAlex.h5")#载入模型
    result = np.squeeze(model.predict(img))#进行预测得到的结果有batch维度,用squeeze压缩
    predict_class = np.argmax(result)#获取概率最大的值所对应的索引
    print(class_indict[str(predict_class)], result[predict_class])#得到分类所属类别


二、导出.tflite文件

        在完成上面的训练和测试过程后,我们将得到保存为.h5格式的模型,在进行导出之前,请确保保存时保存了整个模型而不仅仅是模型的权重参数。

        首先,我们需要将.h5文件转换为.pb文件,在http://www.github.com/amir-abdi/keras_to_tensorflow.githttp://www.github.com/amir-abdi/keras_to_tensorflow.git在该网址下可以找到将.h5文件转换为.pb文件的程序,但是在使用过程中,我们需要将所有的keras都改写为tensorflow.keras,这在前文中提到过,尽管tf中的keras与Keras功能相近,但是要在使用过程中注意区分,否则会引发异常,因为.h5文件是用tf中的keras训练得到的,因此转换时也需要用tf中的keras,改好文件之后,运行keras_to_tensorflow.py,替换下面程序中的路径即可。

得到.pb文件之后,运行下面的代码获得模型的输入输出信息,Input_name和Output_name

def create_graph(model_path):
    with tf.gfile.FastGFile(os.path.join(model_path), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')


def print_io_arrays(pb):
    gf = tf.GraphDef()
    m_file = open(pb, 'rb')
    gf.ParseFromString(m_file.read())

    with open('gfnode.txt', 'a') as the_file:
        for n in gf.node:
            the_file.write(n.name + '\n')

    file = open('gfnode.txt', 'r')
    data = file.readlines()
    print("output name = ")
    print(data[len(data) - 1])
    print("Input name = ")
    file.seek(0)
    print(file.readline())


pd_file_path = 'save_weights/myAlex.pb'
print_io_arrays(pd_file_path)

将Input_name和Output_name信息记录下来,之后运行

tflite_convert --output_file=tflite文件生成的路径 --graph_def_file=pb文件所在的路径 --input_arrays=Input_name --output_arrays=Output_name

ps:

tflite_convert --output_file=save_weights/myAlex.tflite --graph_def_file=save_weights/myAlex.pb --input_arrays=input_1 --output_arrays=softmax/Softmax

运行结束将会生成指定的.tflite文件,下面对该文件进行测试,

def lite_test(lite_model_file, img_path):
    interpreter = tf.lite.Interpreter(model_path=lite_model_file)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    floating_model = input_details[0]['dtype'] == np.float32
    height = input_details[0]['shape'][1]
    width = input_details[0]['shape'][2]
    img = Image.open(img_path).resize((width, height))
    if floating_model:
        img = np.float32(img) / 255
    input_data = np.expand_dims(img, axis=0)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    results = np.squeeze(output_data)
    # read class_indict
    try:
        json_file = open('./class_indices.json', 'r')  # 读取之前保存好的json文件
        class_indict = json.load(json_file)  # 对应的类别信息
    except Exception as e:
        print(e)
        exit(-1)
    predict_class = np.argmax(results)  # 获取概率最大的值所对应的索引
    print(class_indict[str(predict_class)], results[predict_class])  # 得到分类所属类别


lite_file_path = 'save_weights/myAlex.tflite'
img_dir = image_path + "val/person/11.jpg"
lite_test(lite_model_file=lite_file_path, img_path=img_dir)

参考

(17条消息) 使用tflite_convert命令工具将keras h5文件转换为tflite文件简易指南_CQJTU_Andy的博客-CSDN博客_h5转tfliteicon-default.png?t=M276https://blog.csdn.net/qq_42131061/article/details/106209894

你可能感兴趣的:(tensorflow,人工智能,python)